"""会话上下文管理：LLM 对话历史 + Shell Session 绑定。"""

from dataclasses import dataclass, field


@dataclass
class SessionContext:
    """State of a single session: identifying metadata plus LLM chat history.

    Attributes:
        session_id: Identifier of this session.
        user_id: Identifier of the user the session belongs to.
        seq: Sequence number of the session (1 by default).
        name: Optional custom name; an empty string means "not set".
        llm_messages: Accumulated LLM conversation records, oldest first.
    """

    session_id: str
    user_id: str
    seq: int = 1
    name: str = ""
    # A fresh list per instance, so sessions never share history.
    llm_messages: list[dict] = field(default_factory=list)

    @property
    def display_name(self) -> str:
        """Return the custom name when one is set, otherwise the session ID."""
        if self.name:
            return self.name
        return self.session_id


@dataclass
class UserSessions:
    """Every session owned by one user, plus a pointer to the active one.

    Attributes:
        user_id: Identifier of the owning user.
        sessions: Session contexts keyed by session ID.
        active_session_id: ID of the currently active session; empty string
            by default.
        _next_seq: Next sequence number to assign to a new session.
    """

    user_id: str
    sessions: dict[str, SessionContext] = field(default_factory=dict)
    active_session_id: str = ""
    _next_seq: int = 1

    @property
    def active(self) -> SessionContext | None:
        """Return the active session, or None if the active ID is not stored."""
        current_id = self.active_session_id
        return self.sessions.get(current_id, None)


class ContextManager:
    """In-memory registry of session contexts for all users.

    Each user ID maps to a ``UserSessions`` collection. Most operations act
    on that user's active session; lookups for unknown users return an empty
    value (``None``, ``""``, ``[]`` or ``False``) instead of raising.
    """

    def __init__(self) -> None:
        # user_id -> that user's session collection
        self._users: dict[str, UserSessions] = {}

    def _get_or_create_user(self, user_id: str) -> UserSessions:
        """Return the session collection for ``user_id``, creating it on first use."""
        try:
            return self._users[user_id]
        except KeyError:
            created = UserSessions(user_id=user_id)
            self._users[user_id] = created
            return created

    def get_or_create(self, user_id: str, session_id: str) -> SessionContext:
        """Return the user's active session, or create one if there is none.

        ``session_id`` is only used when a new session has to be created; an
        existing active session is returned unchanged whatever its ID.
        The new session takes the user's next sequence number and becomes
        the active session.
        """
        user = self._get_or_create_user(user_id)
        current = user.active
        if current is not None:
            return current

        seq = user._next_seq
        created = SessionContext(session_id=session_id, user_id=user_id, seq=seq)
        user.sessions[session_id] = created
        user.active_session_id = session_id
        user._next_seq = seq + 1
        return created

    def new_session(self, user_id: str, session_id: str, seq: int) -> SessionContext:
        """Create a session with an explicit sequence number and make it active.

        An existing session stored under the same ``session_id`` is replaced.
        The user's next sequence number is raised to ``seq + 1`` when it is
        currently lower, and is never decreased.
        """
        user = self._get_or_create_user(user_id)
        created = SessionContext(session_id=session_id, user_id=user_id, seq=seq)
        user.sessions[session_id] = created
        user.active_session_id = session_id
        following = seq + 1
        if following > user._next_seq:
            user._next_seq = following
        return created

    def clear_context(self, user_id: str) -> SessionContext | None:
        """Reset the active session's LLM history to an empty list.

        Returns the cleared session, or None when the user is unknown or has
        no active session.
        """
        user = self._users.get(user_id)
        target = None if user is None else user.active
        if target is None:
            return None
        # Rebind rather than clear in place, matching a fresh empty history.
        target.llm_messages = []
        return target

    def add_message(self, user_id: str, role: str, content: str) -> None:
        """Append one LLM conversation record to the user's active session.

        Does nothing when the user is unknown or has no active session.
        """
        user = self._users.get(user_id)
        if user is None:
            return
        target = user.active
        if target is None:
            return
        target.llm_messages.append({"role": role, "content": content})

    def get(self, user_id: str) -> SessionContext | None:
        """Return the user's active session, or None if there is none."""
        user = self._users.get(user_id)
        return None if user is None else user.active

    def list_sessions(self, user_id: str) -> list[SessionContext]:
        """Return a new list of all the user's sessions (empty for unknown users)."""
        user = self._users.get(user_id)
        return [] if user is None else list(user.sessions.values())

    def get_active_session_id(self, user_id: str) -> str:
        """Return the active session ID, or an empty string for unknown users."""
        user = self._users.get(user_id)
        return "" if user is None else user.active_session_id

    def switch_session(self, user_id: str, session_id: str) -> SessionContext | None:
        """Make ``session_id`` the active session and return it.

        Returns None, changing nothing, when the user or session is unknown.
        """
        user = self._users.get(user_id)
        if user is not None and session_id in user.sessions:
            user.active_session_id = session_id
            return user.active
        return None

    def rename_session(self, user_id: str, session_id: str, new_name: str) -> SessionContext | None:
        """Set the custom name of a session and return the session.

        Returns None when the user or session is unknown.
        """
        user = self._users.get(user_id)
        if user is None or session_id not in user.sessions:
            return None
        target = user.sessions[session_id]
        target.name = new_name
        return target

    def delete_session(self, user_id: str, session_id: str) -> bool:
        """Delete a session; refuses to remove the user's only session.

        If the deleted session was active, the first remaining session in
        dictionary iteration order becomes active.

        Returns True when the session was deleted, False when the user or
        session is unknown or the session is the user's last one.
        """
        user = self._users.get(user_id)
        if user is None:
            return False
        sessions = user.sessions
        if session_id not in sessions or len(sessions) <= 1:
            return False
        del sessions[session_id]
        if user.active_session_id == session_id:
            user.active_session_id = next(iter(sessions))
        return True
