
import torch
import gc
from typing import List, Optional, Dict, Generator
from PIL import Image
from transformers import (
    AutoModelForCausalLM, 
    AutoTokenizer, 
    AutoProcessor, 
    AutoModelForVision2Seq,
    TextIteratorStreamer
)
from threading import Thread

#--------------------------------------------------------------------------------------------
# LLM推論を制御するメインクラス
#--------------------------------------------------------------------------------------------
class TransformersLLM:
    """Front-end that builds chat messages from user input and runs inference.

    All model handling is delegated to the ``TransformersLLMInterface``
    instance stored in ``self.llm``.
    """

    def __init__(
            self,
            model_path: str,
            use_vision: bool = False,
            temp: float = 0.7,
            top_p: float = 0.95,
            max_tokens: int = 512
    ):
        """Create the backend interface; every argument is forwarded unchanged.

        Args:
            model_path: Model identifier or path handed to the backend.
            use_vision: Whether image inputs are attached to messages.
            temp: Sampling temperature.
            top_p: Nucleus-sampling threshold.
            max_tokens: Maximum number of newly generated tokens.
        """
        self.llm = TransformersLLMInterface(
            model_path=model_path,
            use_vision=use_vision,
            temp=temp,
            top_p=top_p,
            max_tokens=max_tokens
        )

    # The PIL annotation is quoted so it is not evaluated when the class body runs.
    def _build_message(
            self,
            role: str,
            text: str,
            images: "Optional[List[Image.Image]]" = None) -> Dict:
        """Build one chat message dict.

        When images are given and the backend has vision enabled, the content
        is a list holding one text entry followed by one ``{"type": "image"}``
        placeholder per image. Otherwise the content is the plain text.
        """
        if images and self.llm.use_vision:
            content = [{"type": "text", "text": text}]
            content.extend({"type": "image"} for _ in images)
            return {"role": role, "content": content}

        return {"role": role, "content": text}

    @staticmethod
    def _load_images(paths: Optional[List[str]]) -> list:
        """Open every image path and return the images converted to RGB."""
        loaded = []
        for path in paths or []:
            # Image.open keeps the file open lazily; the context manager closes
            # it once convert() has produced an independent in-memory copy.
            with Image.open(path) as source:
                loaded.append(source.convert("RGB"))
        return loaded

    def respond(
            self,
            user_text: str,
            user_images: Optional[List[str]] = None,  # list of image file paths
            system_prompt: Optional[str] = None,
            stream: bool = False) -> Generator[str, None, None]:
        """Generate a reply to ``user_text``.

        This is always a generator, even when ``stream`` is False: in that
        case it yields whatever the backend yields for a non-streaming call
        (a single full response string in the backend defined in this file).

        Args:
            user_text: The user's message text.
            user_images: Optional image file paths to attach to the message.
            system_prompt: Optional system message placed before the user turn.
            stream: Passed to the backend to request incremental text chunks.

        Yields:
            Text produced by the backend generator.
        """
        messages = []
        if system_prompt:
            messages.append(self._build_message(role="system", text=system_prompt))

        pil_images = self._load_images(user_images)
        messages.append(self._build_message(role="user", text=user_text, images=pil_images))

        generator = self.llm.generate(
            messages=messages,
            images=pil_images if pil_images else None,
            stream=stream
        )

        # Drain the backend generator completely in both modes so that its
        # cleanup code runs as soon as the output is exhausted instead of
        # whenever the suspended generator happens to be garbage-collected.
        yield from generator

#--------------------------------------------------------------------------------------------
# Transformers / Torch (MPS) バックエンド
#--------------------------------------------------------------------------------------------
class TransformersLLMInterface:
    """Transformers / Torch backend that loads a model for each generation call.

    The model is loaded at the start of ``generate`` and released again when
    that generator finishes or is closed.
    """

    def __init__(
        self,
        model_path: str,
        use_vision: bool = False,
        temp: float = 0.7,
        top_p: float = 0.95,
        max_tokens: int = 512
    ):
        """Store the generation settings and pick the compute device.

        Args:
            model_path: Identifier or path passed to ``from_pretrained``.
            use_vision: Load a vision-to-sequence model plus processor
                instead of a causal LM plus tokenizer.
            temp: Sampling temperature; values <= 0 select greedy decoding.
            top_p: Nucleus-sampling threshold used when sampling.
            max_tokens: Value passed as ``max_new_tokens``.
        """
        self.model_path = model_path
        self.use_vision = use_vision
        self.temp = temp
        self.top_p = top_p
        self.max_tokens = max_tokens

        # Use the Apple Silicon GPU (MPS) when available, otherwise the CPU.
        self.device = torch.device("mps") if torch.backends.mps.is_available() else torch.device("cpu")

        self.model = None
        self.tokenizer = None
        self.processor = None

    def _load_model(self):
        """Load the model (and tokenizer / processor) onto ``self.device``."""
        # Original author's note: on Mac, float16 is often more stable than bfloat16.
        dtype = torch.float16

        if self.use_vision:
            # Vision-language model: the tokenizer is taken from the processor.
            self.model = AutoModelForVision2Seq.from_pretrained(
                self.model_path, torch_dtype=dtype, low_cpu_mem_usage=True
            ).to(self.device)
            self.processor = AutoProcessor.from_pretrained(self.model_path)
            self.tokenizer = self.processor.tokenizer
        else:
            # Plain text-only LLM.
            self.model = AutoModelForCausalLM.from_pretrained(
                self.model_path, torch_dtype=dtype, low_cpu_mem_usage=True
            ).to(self.device)
            self.tokenizer = AutoTokenizer.from_pretrained(self.model_path)

    def _free_model(self):
        """Drop the model references and ask Python / MPS to release memory."""
        self.model = None
        self.tokenizer = None
        self.processor = None
        gc.collect()
        if torch.backends.mps.is_available():
            torch.mps.empty_cache()  # release cached GPU memory on Mac

    def _encode_inputs(self, messages, images):
        """Apply the chat template and tokenize, returning inputs on ``self.device``."""
        if self.use_vision:
            prompt = self.processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
            return self.processor(text=prompt, images=images, return_tensors="pt").to(self.device)

        prompt = self.tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
        return self.tokenizer(prompt, return_tensors="pt").to(self.device)

    def _sampling_kwargs(self) -> Dict:
        """Return the decoding-strategy keyword arguments for ``model.generate``.

        With a positive temperature, sampling is enabled with the configured
        ``temperature`` and ``top_p``. Otherwise greedy decoding is requested
        and both sampling knobs are passed as ``None`` so that no sampling-only
        option is set while ``do_sample`` is False.
        """
        sampling = self.temp > 0
        return {
            "max_new_tokens": self.max_tokens,
            "do_sample": sampling,
            "temperature": self.temp if sampling else None,
            "top_p": self.top_p if sampling else None,
        }

    def generate(
        self,
        messages: List[Dict],
        images: "Optional[List[Image.Image]]" = None,
        stream: bool = False
    ) -> Generator[str, None, None]:
        """Load the model, generate a reply for ``messages``, then unload.

        Args:
            messages: Chat messages handed to ``apply_chat_template``.
            images: Images passed to the processor; only used when
                ``use_vision`` is True.
            stream: If True, yield text chunks as they are produced;
                otherwise yield the complete decoded response once.

        Yields:
            Decoded text (chunks when streaming, one full string otherwise).

        Raises:
            Exception: Any error raised by ``model.generate`` in the
                streaming worker thread is re-raised in the consumer.
        """
        self._load_model()
        worker = None

        try:
            inputs = self._encode_inputs(messages, images)
            generation_kwargs = dict(**inputs, **self._sampling_kwargs())

            if stream:
                streamer = TextIteratorStreamer(self.tokenizer, skip_prompt=True, skip_special_tokens=True)
                generation_kwargs["streamer"] = streamer

                model = self.model
                errors = []

                def _run_generation():
                    try:
                        model.generate(**generation_kwargs)
                    except Exception as exc:
                        # Without this the consumer would wait on the streamer
                        # queue forever: record the failure and end the stream.
                        errors.append(exc)
                        streamer.end()

                worker = Thread(target=_run_generation)
                worker.start()

                for new_text in streamer:
                    yield new_text

                worker.join()
                if errors:
                    raise errors[0]
            else:
                output_ids = self.model.generate(**generation_kwargs)
                # Strip the prompt tokens and decode only the newly generated part.
                input_len = inputs["input_ids"].shape[1]
                response = self.tokenizer.decode(output_ids[0][input_len:], skip_special_tokens=True)
                yield response

        finally:
            # If the consumer stopped early, wait for the worker so the model
            # is not released while generation is still running on it.
            if worker is not None:
                worker.join()
            self._free_model()