import os
import pty
import select
import time
import subprocess
import re
from typing import Optional

class ShellSession:
    """
    Manages a persistent shell (zsh) session using pty.

    The shell runs on a pseudo-terminal with ``PS1`` set to a unique marker
    followed by a newline. Each command is written to the pty and output is
    read until that marker reappears (or a timeout expires); the echoed
    command line, the marker and terminal escape sequences are then removed.

    Limitations:
    - If a command is still running when its timeout expires, the rest of its
      output (and the prompt that follows it) is returned by the *next* read.
    - NOTE(review): prompt detection relies on the shell taking ``PS1`` from
      the environment; an rc file that overrides it makes every read run
      until its timeout — confirm for the target shell/config.

    Usable as a context manager; ``close()`` is also called from ``__del__``.
    """

    # Terminal escape sequences stripped from command output:
    #   * OSC strings: ESC ] ... terminated by BEL or ESC \ (e.g. window titles)
    #   * CSI sequences: ESC [ params intermediates final (colors, cursor moves)
    #   * two-byte ESC <char> escapes
    # OSC is tried first so its payload is removed, not just the "ESC ]" prefix.
    _ANSI_ESCAPE = re.compile(
        r"\x1B\][^\x07\x1B]*(?:\x07|\x1B\\)"
        r"|\x1B(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])"
    )

    # Seconds to wait after SIGTERM before falling back to SIGKILL in close().
    _TERMINATE_GRACE = 1.0

    def __init__(self, shell_path: str = "/bin/zsh", timeout: float = 5.0):
        """
        Start the shell immediately.

        Args:
            shell_path: Executable to run on the pty.
            timeout: Default seconds to wait for the prompt, used at startup
                and by ``execute`` when no per-call timeout is given.
        """
        self.shell_path = shell_path
        self.timeout = timeout
        self.master_fd: Optional[int] = None  # parent's end of the pty
        self.slave_fd: Optional[int] = None  # child's end; closed in the parent after spawn
        self.process: Optional[subprocess.Popen] = None
        self.buffer = b""  # currently unused by this class
        self.prompt_marker = "__BUTLER_SHELL_PROMPT__"
        self._start()

    def _start(self):
        """Spawn the shell on a fresh pty and consume its startup output."""
        self.master_fd, self.slave_fd = pty.openpty()
        try:
            self.process = subprocess.Popen(
                [self.shell_path],
                stdin=self.slave_fd,
                stdout=self.slave_fd,
                stderr=self.slave_fd,
                # Runs setsid() in the child like preexec_fn=os.setsid, but is
                # safe to use from multi-threaded programs.
                start_new_session=True,
                close_fds=True,
                # A unique PS1 (with a trailing newline) lets
                # _read_until_prompt detect when the shell is ready again.
                env={
                    **os.environ,
                    "PS1": f"{self.prompt_marker}\n",
                    "TERM": "xterm",
                },
            )
        except BaseException:
            # Don't leak the pty if the shell could not be spawned.
            self._close_master()
            raise
        finally:
            # The child has its own copy of the slave end; the parent talks
            # to the shell only through master_fd.
            os.close(self.slave_fd)
            self.slave_fd = None

        # Discard the startup banner and the first prompt.
        self._read_until_prompt(self.timeout)

    def execute(self, cmd: str, timeout: Optional[float] = None) -> str:
        """
        Run *cmd* in the shell and return its cleaned output.

        Restarts the shell first if it has exited.

        Args:
            cmd: Command line to send; a newline is appended.
            timeout: Seconds to wait for the next prompt. ``None`` uses
                ``self.timeout``; ``0`` means a single non-blocking poll.

        Returns:
            The output with the echoed command, the prompt marker and
            terminal escape sequences removed.
        """
        if not self.is_alive():
            # Release the dead shell's pty before opening a new one.
            self._close_master()
            self._start()

        self._write_all((cmd + "\n").encode("utf-8"))
        wait = self.timeout if timeout is None else timeout
        output = self._read_until_prompt(wait)
        return self._clean_output(cmd, output)

    def _write_all(self, data: bytes) -> None:
        """Write all of *data* to ``master_fd``; os.write may write only part."""
        view = memoryview(data)
        while view:
            written = os.write(self.master_fd, view)
            view = view[written:]

    def _read_until_prompt(self, timeout: float = 2.0) -> str:
        """
        Read from ``master_fd`` until the prompt marker is seen or *timeout* expires.

        Also stops on EOF or a read error (e.g. the shell exited). Polls at
        least once, even when *timeout* is 0. Bytes read in the same chunk
        after the marker are included in the result.

        Returns:
            Everything read, decoded as UTF-8 with undecodable bytes replaced.
        """
        fd = self.master_fd
        marker = self.prompt_marker.encode()
        output = bytearray()
        deadline = time.monotonic() + timeout

        while True:
            remaining = max(0.0, deadline - time.monotonic())
            ready, _, _ = select.select([fd], [], [], remaining)
            if ready:
                try:
                    chunk = os.read(fd, 4096)
                except OSError:
                    # e.g. EIO on Linux once the child side of the pty is gone
                    break
                if not chunk:
                    break  # EOF
                # A marker completed by this chunk must start within the last
                # len(marker) - 1 bytes already held, so scan only from there
                # instead of re-scanning the whole output each time.
                search_from = max(0, len(output) - len(marker) + 1)
                output += chunk
                if output.find(marker, search_from) != -1:
                    break
            if time.monotonic() >= deadline:
                break

        return output.decode("utf-8", errors="replace")

    def _clean_output(self, cmd: str, raw_output: str) -> str:
        """
        Remove terminal escapes, the prompt marker and the echoed command.

        The echo is detected heuristically: the first remaining line is
        dropped if it contains the stripped command text. An empty command
        never drops a line. Only the first line is checked, so multi-line
        commands may leave part of their echo in the result.
        """
        text = self._ANSI_ESCAPE.sub("", raw_output)
        text = text.replace(self.prompt_marker, "").strip()

        lines = text.splitlines()
        echoed = cmd.strip()
        if echoed and lines and echoed in lines[0]:
            lines = lines[1:]

        return "\n".join(lines).strip()

    def is_alive(self) -> bool:
        """Return True while the shell process exists and has not exited."""
        return self.process is not None and self.process.poll() is None

    def close(self):
        """
        Terminate the shell and close the pty. Safe to call more than once.

        Sends SIGTERM, then SIGKILL if the shell has not exited within
        ``_TERMINATE_GRACE`` seconds.
        """
        # getattr: __del__ may run on an instance whose attributes were never set.
        process = getattr(self, "process", None)
        if process is not None and process.poll() is None:
            process.terminate()
            try:
                process.wait(timeout=self._TERMINATE_GRACE)
            except subprocess.TimeoutExpired:
                # Interactive shells may ignore SIGTERM; an unbounded wait()
                # here could block forever.
                process.kill()
                process.wait()
        self._close_master()

    def _close_master(self):
        """Close ``master_fd`` if open and forget it so it is never closed twice."""
        fd = getattr(self, "master_fd", None)
        if fd is not None:
            # Clear first: the fd number may later be reused by an unrelated
            # file, and closing it a second time would close that file.
            self.master_fd = None
            os.close(fd)

    def __enter__(self) -> "ShellSession":
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        self.close()

    def __del__(self):
        self.close()
