
import gc
import time
from diffusers_helper.hf_login import login # <<< Re-enabled
import json # <<< Added for preset handling

import os
import random
import re # <<< Re-enabled

# --- HF_HOME setup ---
# If the caller has not exported HF_HOME, point it at an 'hf_download'
# directory next to this script (or under the current working directory when
# the script location is unknown).
# NOTE(review): presumably this sits ahead of the heavy imports below so the
# Hugging Face libraries see the value at import time -- confirm.
hf_home = os.environ.get('HF_HOME')
if hf_home is None:
    try:
        # Resolve the directory that contains this script.
        _script_dir = os.path.dirname(__file__)
    except NameError:
        # __file__ is not defined (e.g. interactive mode): use the CWD instead.
        print("Warning: '__file__' not defined. Setting HF_HOME relative to current directory.")
        default_hf_home = os.path.abspath('./hf_download')
        os.environ['HF_HOME'] = default_hf_home
        print(f"Set HF_HOME env to: {os.environ['HF_HOME']}")
    else:
        # Build the path relative to the script location.
        hf_home_path_1 = os.path.abspath(
            os.path.realpath(os.path.join(_script_dir, './hf_download'))
        )
        os.environ['HF_HOME'] = hf_home_path_1
        print(f"Set HF_HOME env to {os.environ['HF_HOME']}")
# --- end of HF_HOME setup ---

import gradio as gr
import torch
import traceback
import einops
import safetensors.torch as sf # <<< Re-enabled
import numpy as np
import argparse
import math # <<< Re-enabled


from PIL import Image
from diffusers import AutoencoderKLHunyuanVideo
from transformers import LlamaModel, CLIPTextModel, LlamaTokenizerFast, CLIPTokenizer
from diffusers_helper.hunyuan import encode_prompt_conds, vae_decode, vae_encode, vae_decode_fake
from diffusers_helper.utils import save_bcthw_as_mp4, crop_or_pad_yield_mask, soft_append_bcthw, resize_and_center_crop, state_dict_weighted_merge, state_dict_offset_merge, generate_timestamp
from diffusers_helper.models.hunyuan_video_packed import HunyuanVideoTransformer3DModelPacked
from diffusers_helper.pipelines.k_diffusion_hunyuan import sample_hunyuan
from diffusers_helper.memory import cpu, gpu, get_cuda_free_memory_gb, move_model_to_device_with_memory_preservation, offload_model_from_device_for_memory_preservation, fake_diffusers_current_device, DynamicSwapInstaller, unload_complete_models, load_model_as_complete
from diffusers_helper.thread_utils import AsyncStream, async_run
from diffusers_helper.gradio.progress_bar import make_progress_bar_css, make_progress_bar_html
from transformers import SiglipImageProcessor, SiglipVisionModel
from diffusers_helper.clip_vision import hf_clip_vision_encode
from diffusers_helper.bucket_tools import find_nearest_bucket

import torchvision

# --- Imports taken over from "code B" (LoRA / FP8 helpers) ---
# The helpers live in the optional 'utils' package. When it cannot be
# imported, warning-only stand-ins with the same names are defined so that the
# rest of the file never hits a NameError; LoRA and FP8 then simply do nothing.
try:
    from utils.lora_utils import merge_lora_to_state_dict
    from utils.fp8_optimization_utils import optimize_state_dict_with_fp8, apply_fp8_monkey_patch
except ImportError:
    print("Warning: LoRA/FP8 utility functions not found. Please ensure 'utils' directory is present.")
    lora_fp8_available = False

    def merge_lora_to_state_dict(state_dict, lora_file, multiplier, device):
        """Fallback stub: warn and hand back ``state_dict`` untouched."""
        print("Warning: merge_lora_to_state_dict dummy function called. LoRA not applied.")
        return state_dict

    def optimize_state_dict_with_fp8(state_dict, device, target_keys, exclude_keys, move_to_device):
        """Fallback stub: warn and hand back ``state_dict`` untouched."""
        print("Warning: optimize_state_dict_with_fp8 dummy function called. FP8 not applied.")
        return state_dict

    def apply_fp8_monkey_patch(model, state_dict, use_scaled_mm):
        """Fallback stub: warn only; ``model`` is not modified and None is returned."""
        print("Warning: apply_fp8_monkey_patch dummy function called. FP8 not applied.")
else:
    print("Successfully imported LoRA/FP8 utilities.")
    lora_fp8_available = True
# --- end of code B imports ---

def save_bcthw_as_png(x, output_filename):
    """Write tensor ``x`` to ``output_filename`` as a single PNG and return the path.

    ``x`` is rearranged with the pattern ``'b c t h w -> c (b h) (t w)'``, so
    batch entries are stacked along the image height and the ``t`` axis is
    laid out along the image width. Values are clamped to [0, 1] and scaled to
    0..255 before the cast to uint8 (the original note says this matches the
    UI). The parent directory of ``output_filename`` is created if missing.
    """
    target_dir = os.path.dirname(os.path.abspath(os.path.realpath(output_filename)))
    os.makedirs(target_dir, exist_ok=True)

    scaled = torch.clamp(x.float(), 0, 1) * 255
    as_uint8 = scaled.detach().cpu().to(torch.uint8)
    sheet = einops.rearrange(as_uint8, 'b c t h w -> c (b h) (t w)')

    torchvision.io.write_png(sheet, output_filename)
    return output_filename

# Command-line options, registered in exactly this order.
_CLI_OPTIONS = (
    ('--share', dict(action='store_true')),
    ('--server', dict(type=str, default='0.0.0.0')),
    ('--port', dict(type=int, required=False)),
    ('--inbrowser', dict(action='store_true')),
)

parser = argparse.ArgumentParser()
for _flag, _options in _CLI_OPTIONS:
    parser.add_argument(_flag, **_options)
args = parser.parse_args()

print(args)

# High-VRAM mode is selected when more than 60 GB is reported free on the GPU.
free_mem_gb = get_cuda_free_memory_gb(gpu)
high_vram = free_mem_gb > 60

print(f'Free VRAM {free_mem_gb} GB')
print(f'High-VRAM Mode: {high_vram}')

# --- Base model loading (wrapped in try/except for clearer error reporting) ---
# Every model is loaded in float16 and parked on the CPU; device placement is
# decided later in the file. Tokenizers and the feature extractor are loaded
# as-is.
try:
    text_encoder = LlamaModel.from_pretrained("hunyuanvideo-community/HunyuanVideo", subfolder='text_encoder', torch_dtype=torch.float16).cpu()
    text_encoder_2 = CLIPTextModel.from_pretrained("hunyuanvideo-community/HunyuanVideo", subfolder='text_encoder_2', torch_dtype=torch.float16).cpu()
    tokenizer = LlamaTokenizerFast.from_pretrained("hunyuanvideo-community/HunyuanVideo", subfolder='tokenizer')
    tokenizer_2 = CLIPTokenizer.from_pretrained("hunyuanvideo-community/HunyuanVideo", subfolder='tokenizer_2')
    vae = AutoencoderKLHunyuanVideo.from_pretrained("hunyuanvideo-community/HunyuanVideo", subfolder='vae', torch_dtype=torch.float16).cpu()

    feature_extractor = SiglipImageProcessor.from_pretrained("lllyasviel/flux_redux_bfl", subfolder='feature_extractor')
    image_encoder = SiglipVisionModel.from_pretrained("lllyasviel/flux_redux_bfl", subfolder='image_encoder', torch_dtype=torch.float16).cpu()
except Exception as e:
    print(f"Error loading base models: {e}")
    print("Please ensure you have downloaded the necessary model files and have network connectivity.")
    # Stop with a non-zero status so launch scripts can tell the start-up
    # failed. The bare exit() helper is injected by the `site` module for
    # interactive use, reports status 0 when called without an argument, and
    # is missing under `python -S`, so SystemExit is raised directly instead.
    raise SystemExit(1)
# --- end of base model loading ---

# --- Lazy loading of the transformer: module-level state ---
# The transformer is not built at import time; worker() creates it on demand.
# None means "nothing loaded yet".
transformer = None
transformer_dtype = torch.bfloat16

# Settings the currently loaded transformer was configured with. worker()
# compares the incoming request against these to decide whether to rebuild.
previous_fp8_optimization = False
previous_lora_multipliers = [0.0 for _ in range(5)]
previous_lora_files = [None for _ in range(5)]

def load_base_transformer():
    """Load the 'lllyasviel/FramePackI2V_HY' transformer on the CPU and return it.

    The model is loaded with ``transformer_dtype``, switched to eval mode,
    flagged with ``high_quality_fp32_output_for_inference = True`` and has
    gradients disabled.

    Raises:
        Exception: any error raised while loading is printed and re-raised
            unchanged, because the application cannot continue without the
            model.
    """
    print("Loading base transformer ...")
    try:
        model = HunyuanVideoTransformer3DModelPacked.from_pretrained(
            'lllyasviel/FramePackI2V_HY',
            torch_dtype=transformer_dtype,
        )
        model = model.cpu()
        model.eval()
        model.high_quality_fp32_output_for_inference = True
        print('transformer.high_quality_fp32_output_for_inference = True')
        model.to(dtype=transformer_dtype)
        model.requires_grad_(False)
    except Exception as load_error:
        print(f"Error loading transformer model: {load_error}")
        # If there is a way to surface this in the Gradio UI, add it here.
        raise
    return model
# --- end of lazy transformer loading ---

# Put every auxiliary model into eval mode.
for _model in (vae, text_encoder, text_encoder_2, image_encoder):
    _model.eval()

# Outside high-VRAM mode, turn on VAE slicing and tiling.
if not high_vram:
    vae.enable_slicing()
    vae.enable_tiling()

# Cast all auxiliary models to float16.
for _model in (vae, image_encoder, text_encoder, text_encoder_2):
    _model.to(dtype=torch.float16)

# No gradients are needed for any of them.
for _model in (vae, text_encoder, text_encoder_2, image_encoder):
    _model.requires_grad_(False)

if high_vram:
    # High-VRAM mode: keep every auxiliary model on the GPU.
    for _model in (text_encoder, text_encoder_2, image_encoder, vae):
        _model.to(gpu)
else:
    # Otherwise only the first text encoder gets DynamicSwap installed here;
    # the remaining models are moved around by worker() as needed.
    DynamicSwapInstaller.install_model(text_encoder, device=gpu)
del _model

# Message stream shared by loop_worker() and worker().
stream = AsyncStream()

# Directory that receives generated files; created up front.
outputs_folder = './outputs/'
os.makedirs(outputs_folder, exist_ok=True)

# --- loop_worker (argument list updated; no changes required for presets) ---
def loop_worker(input_image, prompt, n_prompt, generation_count, seed, total_second_length, connection_second_length, padding_second_length, loop_num, latent_window_size, steps, cfg, gs, rs, gpu_memory_preservation, use_teacache, mp4_crf, reduce_file_output, without_preview, output_latent_image, latent_input_file, lora_files, lora_multipliers, fp8_optimization):
    """Run ``worker`` up to ``generation_count`` times and report via ``stream``.

    Seed handling: when ``generation_count`` is exactly 1 the supplied
    ``seed`` is used; for any other count every run, including the first,
    draws a fresh random seed in ``[0, 2**32 - 1]``.

    Before and after each run the module-level ``stream`` is polled; the loop
    stops early when ``stream.input_queue.top()`` is ``'end'``. Any
    ``Exception`` escaping the loop is printed and pushed to the output queue
    as an ``('error', message)`` item. Whatever happens, ``('end', None)`` is
    pushed last. All remaining arguments are forwarded to ``worker``
    unchanged.
    """

    def stop_requested():
        # True when a stream exists and the next item on its input queue is 'end'.
        return stream is not None and stream.input_queue.top() == 'end'

    try:
        for run_index in range(generation_count):
            # Stop check at the top of every iteration.
            if stop_requested():
                print("Stop requested before generation loop starts/iterates.")
                break

            # Only a single-generation request keeps the caller's seed.
            run_seed = seed if generation_count == 1 else random.randint(0, 2**32 - 1)
            if stream is not None:
                stream.output_queue.push(('generation count', f"Generation index:{run_index + 1}/{generation_count}, Seed: {run_seed}"))

            worker(input_image, prompt, n_prompt, run_seed, total_second_length, connection_second_length, padding_second_length, loop_num, latent_window_size, steps, cfg, gs, rs, gpu_memory_preservation, use_teacache, mp4_crf, reduce_file_output, without_preview, output_latent_image, latent_input_file, lora_files, lora_multipliers, fp8_optimization)

            # Stop check again once the worker has returned.
            if stop_requested():
                print("Stop requested after worker completion.")
                break

    except Exception as loop_error:
        # Unexpected failure inside loop_worker itself.
        print(f"--- Error in loop_worker: {type(loop_error).__name__} ---")
        traceback.print_exc()
        if stream is not None:
            stream.output_queue.push(('error', f"An unexpected error occurred in the generation loop: {loop_error}"))
    finally:
        # Normal completion, break, or exception: always finish with 'end'.
        if stream is not None:
            stream.output_queue.push(('end', None))
            print("loop_worker finished, pushed 'end'.")
# --- end of loop_worker ---

@torch.no_grad()
# --- workerの引数を更新 (No changes required for presets) ---
def worker(input_image, prompt, n_prompt, seed, total_second_length, connection_second_length, padding_second_length, loop_num, latent_window_size, steps, cfg, gs, rs, gpu_memory_preservation, use_teacache, mp4_crf, reduce_file_output, without_preview, output_latent_image, latent_input_file, lora_files, lora_multipliers, fp8_optimization):
# --- ここまで ---
    global transformer, previous_lora_files, previous_lora_multipliers, previous_fp8_optimization, stream # グローバル変数を参照

    # final_latents needs to be defined early for the finally block
    final_latents = None
    start_latent = None
    # Initialize other history variables that might be used in finally
    history_latents = None
    post_history_latents = None
    history_pixels = None
    post_history_pixels = None
    real_history_latents = None
    post_real_history_latents = None
    final_history_pixels_1loop = None


    # --- 中断チェックを追加 ---
    if stream is not None and stream.input_queue.top() == 'end':
        print("Stop requested at the beginning of worker.")
        # stream.output_queue.push(('end', None)) # loop_workerが最後に送るので不要
        return # workerを早期終了
    # --- ここまで ---

    if reduce_file_output:
        tmp_filename = "system_preview.mp4"

    total_latent_sections = int(max(round((total_second_length * 30) / (latent_window_size * 4)), 1))
    connection_latent_sections = int(max(round((connection_second_length * 30) / (latent_window_size * 4)), 1))

    all_latent_section = total_latent_sections + connection_latent_sections
    print(f"Target sections: Main={total_latent_sections}, Connection={connection_latent_sections}, All={all_latent_section}")


    job_id = generate_timestamp()

    if stream is not None: stream.output_queue.push(('progress', (None, '', make_progress_bar_html(0, 'Starting ...'))))

    # --- メインの try ブロック ---
    try:
        if latent_input_file is None:
            # --- モデル変更チェックとTransformerロード/適用ロジック ---
            model_changed = (transformer is None or
                             lora_files != previous_lora_files or # Compare lists directly
                             lora_multipliers != previous_lora_multipliers or # Compare lists directly
                             (fp8_optimization != previous_fp8_optimization and lora_fp8_available))


            if model_changed:
                 if stream is not None: stream.output_queue.push(("progress", (None, "", make_progress_bar_html(0, "Checking model configuration..."))))
                 if transformer is not None:
                     print("Reconfiguring transformer due to settings change...")
                     unload_complete_models(transformer)
                     transformer = None
                     gc.collect()
                     torch.cuda.empty_cache()
                     time.sleep(1.0) # Small delay for resources to free up

                 if stream is not None: stream.output_queue.push(("progress", (None, "", make_progress_bar_html(0, "Loading transformer model..."))))
                 transformer = load_base_transformer() # エラーハンドリング追加済み

                 # --- Apply LoRA and FP8 if needed ---
                 apply_any_lora = any(f is not None and os.path.exists(f) and m != 0.0 for f, m in zip(lora_files, lora_multipliers))
                 apply_fp8 = fp8_optimization and lora_fp8_available

                 if apply_any_lora or apply_fp8:
                     if stream is not None: stream.output_queue.push(("progress", (None, "", make_progress_bar_html(0, "Loading state dict..."))))
                     state_dict = transformer.state_dict()
                     # Ensure state_dict is on CPU for modification to prevent mixed-device errors
                     state_dict = {k: v.cpu() for k, v in state_dict.items()}
                     gc.collect()

                     if apply_any_lora:
                         print("Applying LoRAs...")
                         for i, (lora_file, lora_multiplier) in enumerate(zip(lora_files, lora_multipliers)):
                             if lora_file is not None and os.path.exists(lora_file) and lora_multiplier != 0.0:
                                 lora_name = os.path.basename(lora_file)
                                 if stream is not None: stream.output_queue.push(("progress", (None, "", make_progress_bar_html(0, f"Applying LoRA {i+1}: {lora_name}..."))))
                                 print(f"Merging LoRA file {lora_name} (Index {i}) with multiplier {lora_multiplier}...")
                                 try:
                                     # Assuming merge_lora_to_state_dict can handle CPU state_dict and place output on 'device'
                                     state_dict = merge_lora_to_state_dict(state_dict, lora_file, lora_multiplier, device=gpu) # target device gpu
                                     state_dict = {k: v.cpu() for k, v in state_dict.items()} # move back to CPU for next LoRA or FP8
                                     gc.collect()
                                 except Exception as e:
                                     print(f"Error merging LoRA {lora_name} (Index {i}): {e}")
                                     traceback.print_exc()
                                     if stream is not None: stream.output_queue.push(("error", f"Error merging LoRA {lora_name}: {e}"))
                                     # Optionally, decide if to continue or raise based on severity
                                     continue # Skip this LoRA and try others
                             else:
                                 print(f"Skipping LoRA at index {i}: File invalid or multiplier zero.")
                         print("Finished applying LoRAs.")


                     if apply_fp8:
                         if stream is not None: stream.output_queue.push(("progress", (None, "", make_progress_bar_html(0, "Applying FP8 Optimization..."))))
                         TARGET_KEYS = ["transformer_blocks", "single_transformer_blocks"]
                         EXCLUDE_KEYS = ["norm"]
                         print("Optimizing state dict for fp8...")
                         try:
                            # Ensure state_dict is on CPU before FP8 optimization
                            state_dict = {k: v.cpu() for k, v in state_dict.items()}
                            gc.collect()
                            # optimize_state_dict_with_fp8 assumes input state_dict is on CPU, output might be on device or CPU
                            state_dict = optimize_state_dict_with_fp8(state_dict, gpu, TARGET_KEYS, EXCLUDE_KEYS, move_to_device=False) # output on CPU
                            gc.collect()
                            print("Applying FP8 monkey patch...")
                            apply_fp8_monkey_patch(transformer, state_dict, use_scaled_mm=False) # This might move parts of model to GPU
                            gc.collect()
                         except Exception as e:
                             print(f"Error applying FP8 optimization: {e}")
                             traceback.print_exc()
                             if stream is not None: stream.output_queue.push(("error", f"Error applying FP8 optimization: {e}"))
                             raise e # FP8 error might be critical

                     if stream is not None: stream.output_queue.push(("progress", (None, "", make_progress_bar_html(0, "Loading state dict into transformer..."))))
                     # Ensure state_dict is on CPU before loading into CPU model
                     state_dict = {k: v.cpu() for k, v in state_dict.items()}
                     gc.collect()
                     info = transformer.load_state_dict(state_dict, strict=True, assign=True) # strict=True if all keys should match
                     print(f"LoRA/FP8 state dict loaded info: {info}")
                     del state_dict
                     gc.collect()

                 if not high_vram:
                     print("Installing DynamicSwap for transformer...")
                     DynamicSwapInstaller.install_model(transformer, device=gpu)
                 else:
                     print("Moving configured transformer to GPU...")
                     transformer.to(gpu)

                 previous_lora_files = lora_files[:]
                 previous_lora_multipliers = lora_multipliers[:]
                 previous_fp8_optimization = fp8_optimization
                 if stream is not None: stream.output_queue.push(("progress", (None, "", make_progress_bar_html(0, "Transformer configured."))))
                 gc.collect()
                 torch.cuda.empty_cache()

            elif transformer is None: # First run, no changes needed or available
                 if stream is not None: stream.output_queue.push(("progress", (None, "", make_progress_bar_html(0, "Loading base transformer..."))))
                 transformer = load_base_transformer()
                 if not high_vram:
                     print("Installing DynamicSwap for transformer...")
                     DynamicSwapInstaller.install_model(transformer, device=gpu)
                 else:
                     print("Moving base transformer to GPU...")
                     transformer.to(gpu)
                 previous_lora_files = lora_files[:]
                 previous_lora_multipliers = lora_multipliers[:]
                 previous_fp8_optimization = fp8_optimization
                 gc.collect()
                 torch.cuda.empty_cache()
            # --- ここまで Transformerロード/適用ロジック ---

            if not high_vram:
                unload_complete_models(text_encoder_2, image_encoder, vae) # Unload others before loading text_encoder
                fake_diffusers_current_device(text_encoder, gpu) # Ensure text_encoder is on GPU for DynamicSwap
            gc.collect()
            torch.cuda.empty_cache()

            # Text encoding
            if stream is not None: stream.output_queue.push(('progress', (None, '', make_progress_bar_html(0, 'Text encoding ...'))))
            # --- 中断チェック ---
            if stream is not None and stream.input_queue.top() == 'end': raise KeyboardInterrupt("Stop requested by user.")
            if not high_vram:
                print("Loading text encoder 2 (Low VRAM)...")
                load_model_as_complete(text_encoder_2, target_device=gpu)
            else:
                print("Ensuring text encoders are on GPU (High VRAM)...")
                text_encoder.to(gpu)
                text_encoder_2.to(gpu)

            llama_vec, clip_l_pooler = encode_prompt_conds(prompt, text_encoder, text_encoder_2, tokenizer, tokenizer_2)
            if cfg == 1: # Assuming cfg is numeric
                print("Using zero embeddings for unconditional guidance (CFG=1)")
                llama_vec_n, clip_l_pooler_n = torch.zeros_like(llama_vec), torch.zeros_like(clip_l_pooler)
            else:
                print("Encoding negative prompt...")
                llama_vec_n, clip_l_pooler_n = encode_prompt_conds(n_prompt, text_encoder, text_encoder_2, tokenizer, tokenizer_2)

            llama_vec, llama_attention_mask = crop_or_pad_yield_mask(llama_vec, length=512)
            llama_vec_n, llama_attention_mask_n = crop_or_pad_yield_mask(llama_vec_n, length=512)

            # Processing input image
            if stream is not None: stream.output_queue.push(('progress', (None, '', make_progress_bar_html(0, 'Image processing ...'))))
             # --- 中断チェック ---
            if stream is not None and stream.input_queue.top() == 'end': raise KeyboardInterrupt("Stop requested by user.")
            H, W, C = input_image.shape
            height, width = find_nearest_bucket(H, W, resolution=640)
            print(f"Input image HWC: {H, W, C}. Resized/Cropped to: {height, width}")
            input_image_np = resize_and_center_crop(input_image, target_width=width, target_height=height)
            if not without_preview:
                Image.fromarray(input_image_np).save(os.path.join(outputs_folder, f'{job_id}_{seed}.png'))
            input_image_pt = torch.from_numpy(input_image_np).float() / 127.5 - 1.0
            input_image_pt = input_image_pt.permute(2, 0, 1)[None, :, None]

            # VAE encoding
            if stream is not None: stream.output_queue.push(('progress', (None, '', make_progress_bar_html(0, 'VAE encoding ...'))))
             # --- 中断チェック ---
            if stream is not None and stream.input_queue.top() == 'end': raise KeyboardInterrupt("Stop requested by user.")
            if not high_vram:
                print("Loading VAE (Low VRAM)...")
                unload_complete_models(text_encoder, text_encoder_2) # Unload text encoders
                load_model_as_complete(vae, target_device=gpu)
            else:
                print("Ensuring VAE is on GPU (High VRAM)...")
                vae.to(gpu)

            print(f"Encoding image to latent space... Input shape: {input_image_pt.shape}")
            start_latent = vae_encode(input_image_pt, vae)
            print(f"Initial latent created. Shape: {start_latent.shape}, Device: {start_latent.device}")


            # CLIP Vision
            if stream is not None: stream.output_queue.push(('progress', (None, '', make_progress_bar_html(0, 'CLIP Vision encoding ...'))))
             # --- 中断チェック ---
            if stream is not None and stream.input_queue.top() == 'end': raise KeyboardInterrupt("Stop requested by user.")
            if not high_vram:
                print("Loading Image Encoder (Low VRAM)...")
                unload_complete_models(vae) # Unload VAE
                load_model_as_complete(image_encoder, target_device=gpu)
            else:
                print("Ensuring Image Encoder is on GPU (High VRAM)...")
                image_encoder.to(gpu)

            print("Encoding image with CLIP Vision...")
            image_encoder_output = hf_clip_vision_encode(input_image_np, feature_extractor, image_encoder)
            image_encoder_last_hidden_state = image_encoder_output.last_hidden_state
            print(f"CLIP Vision embedding created. Shape: {image_encoder_last_hidden_state.shape}")


            # Dtype
            target_dtype = transformer.dtype
            print(f"Converting embeddings to target dtype: {target_dtype}")
            llama_vec = llama_vec.to(target_dtype)
            llama_vec_n = llama_vec_n.to(target_dtype)
            clip_l_pooler = clip_l_pooler.to(target_dtype)
            clip_l_pooler_n = clip_l_pooler_n.to(target_dtype)
            image_encoder_last_hidden_state = image_encoder_last_hidden_state.to(target_dtype)

            start_latent = start_latent.to(dtype=torch.float32) # Ensure start_latent is float32 for history
            print(f"Start latent converted to float32. Shape: {start_latent.shape}, Device: {start_latent.device}")

            # Sampling
            if stream is not None: stream.output_queue.push(('progress', (None, '', make_progress_bar_html(0, 'Start sampling ...'))))
            rnd = torch.Generator("cpu").manual_seed(seed)
            print(f"Using seed: {seed}")
            num_frames = latent_window_size * 4 - 3

            ##メイン作成
            history_latents = torch.zeros(size=(1, 16, 1 + 2 + 16, height // 8, width // 8), dtype=torch.float32).cpu()
            history_pixels = None
            total_generated_latent_frames = 0
            latent_paddings = [i for i in reversed(range(total_latent_sections))]
            if total_latent_sections > 4:
                latent_paddings = [3] + [2] * (total_latent_sections - 3) + [1, 0]
            print(f"Main section padding sequence: {latent_paddings}")

            for section_idx, latent_padding in enumerate(latent_paddings): # Add section_idx for logging
                # --- 中断チェック (ループの先頭) ---
                if stream is not None and stream.input_queue.top() == 'end':
                    print("Stop requested before main sampling loop iteration.")
                    raise KeyboardInterrupt("Stop requested by user.")
                # --- ここまで ---

                is_last_section = latent_padding == 0
                latent_padding_init_size = int(padding_second_length * latent_window_size)
                latent_padding_size = (latent_padding * latent_window_size) + latent_padding_init_size
                print(f'--- Main Section {section_idx + 1}/{len(latent_paddings)} (Padding: {latent_padding}, Size: {latent_padding_size}, Is Last: {is_last_section}) ---')

                indices = torch.arange(0, sum([1, latent_padding_size, latent_window_size, 1, 2, 16])).unsqueeze(0)
                clean_latent_indices_pre, blank_indices, latent_indices, clean_latent_indices_post, clean_latent_2x_indices, clean_latent_4x_indices = indices.split([1, latent_padding_size, latent_window_size, 1, 2, 16], dim=1)
                clean_latent_indices = torch.cat([clean_latent_indices_pre, clean_latent_indices_post], dim=1)

                print(f"Preparing main clean latents from history shape: {history_latents.shape}")
                clean_latents_pre = start_latent.to(history_latents) # Ensure same device/dtype for cat
                clean_latents_post, clean_latents_2x, clean_latents_4x = history_latents[:, :, :1 + 2 + 16, :, :].split([1, 2, 16], dim=2)
                clean_latents = torch.cat([clean_latents_pre, clean_latents_post], dim=2)
                print(f"Main context shapes: 1x={clean_latents.shape}, 2x={clean_latents_2x.shape}, 4x={clean_latents_4x.shape}")


                if not high_vram:
                    print("Loading Transformer (Low VRAM)...")
                    unload_complete_models() # Unload all non-transformer models
                    move_model_to_device_with_memory_preservation(transformer, target_device=gpu, preserved_memory_gb=gpu_memory_preservation)
                    gc.collect()
                    torch.cuda.empty_cache()
                else:
                    print("Ensuring Transformer is on GPU (High VRAM)...")
                    transformer.to(gpu)


                if use_teacache:
                    print("Initializing TeaCache...")
                    transformer.initialize_teacache(enable_teacache=True, num_steps=steps)
                else:
                    print("Disabling TeaCache...")
                    transformer.initialize_teacache(enable_teacache=False)

                # --- callback内でKeyboardInterruptをraiseするように変更 ---
                def callback(d):
                    preview_latent = d['denoised'] # Changed from 'preview'
                    preview = vae_decode_fake(preview_latent) # Changed from 'preview'
                    preview = (preview.float() * 255.0).clamp(0, 255) # Ensure float for multiplication
                    preview = preview.detach().cpu().numpy().astype(np.uint8) # clip(0, 255) before astype
                    preview = einops.rearrange(preview, 'b c t h w -> (b h) (t w) c')
                    # --- 中断チェック: streamが存在し、'end'がキューにあるか確認 ---
                    if stream is not None and stream.input_queue.top() == 'end':
                        # stream.output_queue.push(('end', None)) # ここでは送らない
                        raise KeyboardInterrupt('User ends the task.') # 例外を発生させる
                    # --- ここまで ---
                    current_step = d['i'] + 1
                    percentage = int(100.0 * current_step / steps)
                    hint = f'Sampling {current_step}/{steps}'
                    desc = f'Total generated frames: {int(max(0, total_generated_latent_frames * 4 - 3))}, Video length: {max(0, (total_generated_latent_frames * 4 - 3) / 30) :.2f} seconds (FPS-30). Generating main video...'
                    # streamがNoneでないことを確認してからpush
                    if stream is not None:
                         stream.output_queue.push(('progress', (preview, desc, make_progress_bar_html(percentage, hint))))
                    return
                # --- ここまで ---

                print("Starting K-diffusion sampling for main section...")
                generated_latents = sample_hunyuan(
                    transformer=transformer, sampler='unipc', width=width, height=height, frames=num_frames,
                    real_guidance_scale=cfg, distilled_guidance_scale=gs, guidance_rescale=rs,
                    num_inference_steps=steps, generator=rnd, prompt_embeds=llama_vec.to(gpu), prompt_embeds_mask=llama_attention_mask.to(gpu), # Ensure on GPU
                    prompt_poolers=clip_l_pooler.to(gpu), negative_prompt_embeds=llama_vec_n.to(gpu), negative_prompt_embeds_mask=llama_attention_mask_n.to(gpu),
                    negative_prompt_poolers=clip_l_pooler_n.to(gpu), device=gpu, dtype=transformer.dtype,
                    image_embeddings=image_encoder_last_hidden_state.to(gpu), latent_indices=latent_indices.to(gpu),
                    clean_latents=clean_latents.to(gpu, dtype=transformer.dtype), clean_latent_indices=clean_latent_indices.to(gpu), # Ensure on GPU and correct dtype
                    clean_latents_2x=clean_latents_2x.to(gpu, dtype=transformer.dtype), clean_latent_2x_indices=clean_latent_2x_indices.to(gpu),
                    clean_latents_4x=clean_latents_4x.to(gpu, dtype=transformer.dtype), clean_latent_4x_indices=clean_latent_4x_indices.to(gpu),
                    callback=callback, # ここで KeyboardInterrupt が発生する可能性
                )
                print(f"Sampling finished. Generated latents shape: {generated_latents.shape}")


                total_generated_latent_frames += int(generated_latents.shape[2])
                history_latents = torch.cat([generated_latents.cpu().to(history_latents.dtype), history_latents], dim=2) # Move to CPU, match dtype
                real_history_latents = history_latents[:, :, :total_generated_latent_frames, :, :]
                print(f"Real history latents shape after main section: {real_history_latents.shape}")


                if not without_preview:
                     # --- 中断チェック ---
                    if stream is not None and stream.input_queue.top() == 'end': raise KeyboardInterrupt("Stop requested by user.")
                    if not high_vram:
                        print("Loading VAE for decoding (Low VRAM)...")
                        offload_model_from_device_for_memory_preservation(transformer, target_device=gpu, preserved_memory_gb=gpu_memory_preservation) # Use correct preserved_memory_gb
                        load_model_as_complete(vae, target_device=gpu)
                    else:
                        print("Ensuring VAE is on GPU for decoding (High VRAM)...")
                        vae.to(gpu)


                    section_latent_frames_to_decode = latent_window_size * 2
                    overlapped_frames = latent_window_size * 4 - 3
                    # Decode only the newly generated part (or a relevant window)
                    current_latents_to_decode = real_history_latents[:, :, :section_latent_frames_to_decode] # Decode from the start of real history

                    print(f"Decoding main section latents of shape: {current_latents_to_decode.shape}")
                    if current_latents_to_decode.shape[2] > 0:
                        current_pixels = vae_decode(current_latents_to_decode.to(gpu, dtype=vae.dtype), vae).cpu() # Ensure on GPU with VAE dtype
                        if history_pixels is None:
                            history_pixels = current_pixels
                            print(f"Initial pixels decoded. Shape: {history_pixels.shape}")
                        else:
                            print(f"Soft appending main pixels: current shape {current_pixels.shape}, history shape {history_pixels.shape}, overlap {overlapped_frames}")
                            history_pixels = soft_append_bcthw(current_pixels, history_pixels, overlapped_frames)
                            print(f"History pixels shape after append: {history_pixels.shape}")
                    else:
                         print("Skipping VAE decode for main section as there are no frames in current_latents_to_decode.")


                    if not high_vram:
                        print("Unloading VAE after decoding (Low VRAM)...")
                        unload_complete_models(vae)
                        gc.collect()
                        torch.cuda.empty_cache()


                    if history_pixels is not None and history_pixels.shape[2] > 0:
                        if reduce_file_output:
                            output_filename = os.path.join(outputs_folder, tmp_filename)
                        else:
                            output_filename = os.path.join(outputs_folder, f'{job_id}_{total_generated_latent_frames}_{seed}.mp4')
                        print(f"Saving intermediate main video to {output_filename}...")
                        save_bcthw_as_mp4(history_pixels, output_filename, fps=30, crf=mp4_crf)
                        print(f'Main Decoded. Latent shape {real_history_latents.shape}; pixel shape {history_pixels.shape}')
                        # streamがNoneでないことを確認してからpush
                        if stream is not None:
                             stream.output_queue.push(('file', output_filename))
                    else:
                        print("Skipping saving main video as history_pixels is None or empty.")


                if is_last_section:
                    break

            ##コネクション作成
            post_history_latents = real_history_latents # Use the complete history from main generation
            post_history_pixels = history_pixels
            post_total_generated_latent_frames = total_generated_latent_frames
            print(f"Generate connection video starting from latent shape: {post_history_latents.shape}")

            latent_paddings = [i for i in reversed(range(connection_latent_sections))]
            if connection_latent_sections > 4:
                 latent_paddings = [3] + [2] * (connection_latent_sections - 3) + [1, 0]
            print(f"Connection section padding sequence: {latent_paddings}")


            if total_latent_sections + connection_latent_sections > 2: N = 16
            elif total_latent_sections + connection_latent_sections == 2: N = 15
            else: N=6
            print(f"Using N={N} for connection sampling.")

            for section_idx, latent_padding in enumerate(latent_paddings): # Add section_idx for logging
                # --- 中断チェック (ループの先頭) ---
                if stream is not None and stream.input_queue.top() == 'end':
                    print("Stop requested before connection sampling loop iteration.")
                    raise KeyboardInterrupt("Stop requested by user.")
                # --- ここまで ---

                is_last_section = latent_padding == 0
                latent_padding_size = latent_padding * latent_window_size
                print(f'--- Connection Section {section_idx + 1}/{len(latent_paddings)} (Padding: {latent_padding}, Size: {latent_padding_size}, Is Last: {is_last_section}) ---')

                indices = torch.arange(0, sum([1,latent_padding_size, latent_window_size, 1, 2, N])).unsqueeze(0)
                clean_latent_indices_pre,blank_indices, latent_indices, clean_latent_indices_post, clean_latent_2x_indices, clean_latent_4x_indices = \
                    indices.split([1,latent_padding_size, latent_window_size, 1, 2, N], dim=1)
                clean_latent_indices = torch.cat([clean_latent_indices_pre, clean_latent_indices_post], dim=1)

                print(f"Preparing connection clean latents from history shape: {post_history_latents.shape}, needing 1+2+{N}={3+N} frames for context")
                required_len_for_split = 1 + 2 + N
                available_len = post_history_latents.shape[2]

                clean_latents_pre  = post_history_latents[:, :, -1:, :, :] # Last frame of current history

                if available_len < required_len_for_split:
                    print(f"Warning: Connection context requires {required_len_for_split} frames, but only {available_len} available. Padding with the start_latent.")
                    # Get what's available from the beginning of post_history_latents
                    context_latents_to_split = post_history_latents[:, :, :available_len, :, :]
                    padding_needed = required_len_for_split - available_len
                    # Pad with start_latent (which should be on CPU, float32)
                    padding_tensor = start_latent.to(context_latents_to_split.device, dtype=context_latents_to_split.dtype).repeat(1, 1, padding_needed, 1, 1)
                    # Pad at the beginning (older part) of the context_latents_to_split
                    context_latents_to_split = torch.cat([padding_tensor, context_latents_to_split], dim=2)
                    print(f"Padded context tensor shape for split: {context_latents_to_split.shape}")
                else:
                    # Take from the beginning of post_history_latents
                    context_latents_to_split = post_history_latents[:, :, :required_len_for_split, :, :]

                # Now split the (potentially padded) context tensor
                clean_latents_post, clean_latents_2x, clean_latents_4x = context_latents_to_split.split([1, 2, N], dim=2)
                clean_latents = torch.cat([clean_latents_pre, clean_latents_post], dim=2) # Combine last frame with the first from context
                print(f"Connection context shapes: 1x={clean_latents.shape}, 2x={clean_latents_2x.shape}, 4x={clean_latents_4x.shape}")


                if not high_vram:
                    print("Loading Transformer (Low VRAM)...")
                    unload_complete_models()
                    move_model_to_device_with_memory_preservation(transformer, target_device=gpu, preserved_memory_gb=gpu_memory_preservation)
                    gc.collect()
                    torch.cuda.empty_cache()
                else:
                    print("Ensuring Transformer is on GPU (High VRAM)...")
                    transformer.to(gpu)


                if use_teacache:
                    print("Initializing TeaCache for connection...")
                    transformer.initialize_teacache(enable_teacache=True, num_steps=steps)
                else:
                    print("Disabling TeaCache for connection...")
                    transformer.initialize_teacache(enable_teacache=False)

                # --- callback_conn内でKeyboardInterruptをraiseするように変更 ---
                def callback_conn(d):
                    """Per-step sampler callback used while generating the connection section.

                    Checks for a stop request first, then builds a preview image from the
                    current denoised latents and pushes a progress update to the stream.

                    Args:
                        d: Callback payload from the sampler. This function reads
                            ``d['denoised']`` (latents used for the preview) and
                            ``d['i']`` (step index; reported as ``i + 1`` -- presumably
                            zero-based, confirm against the sampler).

                    Raises:
                        KeyboardInterrupt: when ``stream`` exists and the top of its
                            input queue is ``'end'``.
                    """
                    # Without a stream there is neither a stop signal to honour nor a
                    # consumer for the preview, so skip the decode entirely.
                    if stream is None:
                        return

                    # Honour a stop request before doing any preview work so that the
                    # interruption is not delayed by a decode whose result is discarded.
                    if stream.input_queue.top() == 'end':
                        raise KeyboardInterrupt('User ends the task.')

                    # Approximate decode of the latents, scaled to 0..255 and clamped
                    # before the uint8 cast so out-of-range values do not wrap around.
                    preview = vae_decode_fake(d['denoised'])
                    preview = (preview.float() * 255.0).clamp(0, 255)
                    preview = preview.detach().cpu().numpy().astype(np.uint8)
                    # Lay the result out as one 2-D image: batch stacked vertically,
                    # frames side by side horizontally, channels last.
                    preview = einops.rearrange(preview, 'b c t h w -> (b h) (t w) c')

                    current_step = d['i'] + 1
                    percentage = int(100.0 * current_step / steps)
                    hint = f'Sampling {current_step}/{steps}'
                    # Reads post_total_generated_latent_frames at call time, so the text
                    # reflects the count accumulated by earlier connection sections.
                    desc = f'Total generated frames: {int(max(0, post_total_generated_latent_frames * 4 - 3))}, Video length: {max(0, (post_total_generated_latent_frames * 4 - 3) / 30) :.2f} seconds (FPS-30). Generating connection video...'
                    stream.output_queue.push(('progress', (preview, desc, make_progress_bar_html(percentage, hint))))
                # --- ここまで ---

                print("Starting K-diffusion sampling for connection section...")
                generated_latents = sample_hunyuan(
                    transformer=transformer, sampler='unipc', width=width, height=height, frames=num_frames,
                    real_guidance_scale=cfg, distilled_guidance_scale=gs, guidance_rescale=rs,
                    num_inference_steps=steps, generator=rnd, prompt_embeds=llama_vec.to(gpu), prompt_embeds_mask=llama_attention_mask.to(gpu),
                    prompt_poolers=clip_l_pooler.to(gpu), negative_prompt_embeds=llama_vec_n.to(gpu), negative_prompt_embeds_mask=llama_attention_mask_n.to(gpu),
                    negative_prompt_poolers=clip_l_pooler_n.to(gpu), device=gpu, dtype=transformer.dtype,
                    image_embeddings=image_encoder_last_hidden_state.to(gpu), latent_indices=latent_indices.to(gpu),
                    clean_latents=clean_latents.to(gpu, dtype=transformer.dtype), clean_latent_indices=clean_latent_indices.to(gpu),
                    clean_latents_2x=clean_latents_2x.to(gpu, dtype=transformer.dtype), clean_latent_2x_indices=clean_latent_2x_indices.to(gpu),
                    clean_latents_4x=clean_latents_4x.to(gpu, dtype=transformer.dtype), clean_latent_4x_indices=clean_latent_4x_indices.to(gpu),
                    callback=callback_conn, # ここで KeyboardInterrupt が発生する可能性
                )
                print(f"Connection sampling finished. Generated latents shape: {generated_latents.shape}")


                post_total_generated_latent_frames += int(generated_latents.shape[2])
                post_history_latents = torch.cat([generated_latents.cpu().to(post_history_latents.dtype), post_history_latents], dim=2) # Move to CPU, match dtype
                post_real_history_latents = post_history_latents[:, :, :post_total_generated_latent_frames, :, :]
                print(f"Post real history latents shape after connection section: {post_real_history_latents.shape}")


                if not without_preview:
                     # --- 中断チェック ---
                    if stream is not None and stream.input_queue.top() == 'end': raise KeyboardInterrupt("Stop requested by user.")
                    if not high_vram:
                        print("Loading VAE for connection decoding (Low VRAM)...")
                        offload_model_from_device_for_memory_preservation(transformer, target_device=gpu, preserved_memory_gb=gpu_memory_preservation)
                        load_model_as_complete(vae, target_device=gpu)
                    else:
                        print("Ensuring VAE is on GPU for connection decoding (High VRAM)...")
                        vae.to(gpu)


                    section_latent_frames_to_decode = latent_window_size * 2
                    overlapped_frames = latent_window_size * 4 - 3
                    current_latents_to_decode = post_real_history_latents[:, :, :section_latent_frames_to_decode]

                    print(f"Decoding connection section latents of shape: {current_latents_to_decode.shape}")
                    if current_latents_to_decode.shape[2] > 0:
                         current_pixels = vae_decode(current_latents_to_decode.to(gpu, dtype=vae.dtype), vae).cpu()
                         if post_history_pixels is None:
                             print("Warning: post_history_pixels is None during connection VAE decode. Initializing.")
                             post_history_pixels = current_pixels
                             print(f"Initial connection pixels decoded. Shape: {post_history_pixels.shape}")
                         else:
                             print(f"Soft appending connection pixels: current shape {current_pixels.shape}, history shape {post_history_pixels.shape}, overlap {overlapped_frames}")
                             post_history_pixels = soft_append_bcthw(current_pixels, post_history_pixels, overlapped_frames)
                             print(f"Post history pixels shape after append: {post_history_pixels.shape}")
                    else:
                         print("Skipping VAE decode for connection section as there are no frames in current_latents_to_decode.")


                    if not high_vram:
                        print("Unloading VAE after connection decoding (Low VRAM)...")
                        unload_complete_models(vae)
                        gc.collect()
                        torch.cuda.empty_cache()


                    if post_history_pixels is not None and post_history_pixels.shape[2] > 0:
                        if reduce_file_output:
                            output_filename = os.path.join(outputs_folder, tmp_filename)
                        else:
                            output_filename = os.path.join(outputs_folder, f'{job_id}_{post_total_generated_latent_frames}_{seed}_post.mp4')
                        print(f"Saving intermediate connection video to {output_filename}...")
                        save_bcthw_as_mp4(post_history_pixels, output_filename, fps=30, crf=mp4_crf)
                        print(f'Connection Decoded. Latent shape {post_real_history_latents.shape}; pixel shape {post_history_pixels.shape}')
                        # streamがNoneでないことを確認してからpush
                        if stream is not None:
                             stream.output_queue.push(('file', output_filename))
                    else:
                        print("Skipping saving connection video as post_history_pixels is None or empty.")


                if is_last_section:
                    break

            # --- Start of Code A block replacement ---
            #1ループ作成
            if 'post_real_history_latents' in locals() and 'real_history_latents' in locals() and \
               post_real_history_latents is not None and real_history_latents is not None:

                target_connection_len = latent_window_size * connection_latent_sections
                target_main_len = latent_window_size * total_latent_sections

                # Ensure slicing does not go out of bounds
                connection_history_latents_for_loop = post_real_history_latents[:, :, :min(target_connection_len, post_real_history_latents.shape[2]), :, :]
                main_history_latents_for_loop = real_history_latents[:, :, :min(target_main_len, real_history_latents.shape[2]), :, :]

                print(f"Sliced history shapes for loop: Connection={connection_history_latents_for_loop.shape}, Main={main_history_latents_for_loop.shape}")

                # Check if there are enough frames for concatenation
                can_concat = (connection_history_latents_for_loop.shape[2] >= latent_window_size and
                              main_history_latents_for_loop.shape[2] >= latent_window_size and
                              connection_history_latents_for_loop.shape[2] > 0 and # Ensure not empty
                              main_history_latents_for_loop.shape[2] > 0) # Ensure not empty


                if can_concat:
                    final_latents = torch.cat([
                        connection_history_latents_for_loop[:, :, -latent_window_size:, :, :], # Last LWS frames of connection
                        main_history_latents_for_loop,                                        # Full main history
                        connection_history_latents_for_loop,                                  # Full connection history
                        main_history_latents_for_loop[:, :, -latent_window_size:, :, :]       # Last LWS frames of main
                    ], dim=2)
                    print(f"Final combined latent for loop shape: {final_latents.shape}")

                    if output_latent_image:
                         # --- 中断チェック ---
                        if stream is not None and stream.input_queue.top() == 'end': raise KeyboardInterrupt("Stop requested by user.")

                        # Combine main and connection for PNG preview (as in original logic)
                        to_pixel_latents = torch.cat([main_history_latents_for_loop,
                                                       connection_history_latents_for_loop], dim=2)
                        print(f"Shape for latent PNG decode: {to_pixel_latents.shape}")

                        if to_pixel_latents.shape[2] > 0:
                             if not high_vram:
                                 print("Loading VAE for latent PNG (Low VRAM)...")
                                 load_model_as_complete(vae, target_device=gpu)
                             else:
                                 vae.to(gpu)
                             to_pixel_latents_png = vae_decode_fake(to_pixel_latents.to(gpu, dtype=vae.dtype), vae).cpu() # ensure GPU, VAE dtype
                             if not high_vram:
                                 print("Unloading VAE after latent PNG (Low VRAM)...")
                                 unload_complete_models(vae)


                             output_filename_png = os.path.join(outputs_folder, f'{job_id}_{total_generated_latent_frames}_{seed}_latent.png')
                             save_bcthw_as_png(to_pixel_latents_png, output_filename_png)
                             if stream is not None: stream.output_queue.push(('message', f"Latent preview saved: {os.path.basename(output_filename_png)}"))
                        else:
                            print("Warning: Skipping latent PNG generation as concatenated latents for PNG are empty.")


                        output_filename_pt = os.path.join(outputs_folder, f'{job_id}_{total_generated_latent_frames}_{seed}_latent.pt')
                        torch.save(final_latents.cpu(), output_filename_pt) # Ensure saving CPU tensor
                        if stream is not None: stream.output_queue.push(('message', f"Latent file saved: {os.path.basename(output_filename_pt)}"))
                        # --- latent保存後はworkerを終了 ---
                        print("Worker finished after saving latent file.")
                        return # Exit worker
                        # --- ここまで ---
                else:
                    print("Warning: Not enough frames in history latents to create final loop tensor. Skipping loop generation and latent saving.")
                    if stream is not None: stream.output_queue.push(('warning', "Loop generation & latent saving skipped: Insufficient historical frames."))
                    final_latents = None # Mark as None if failed

            else:
                print("Warning: History latents required for loop creation are missing. Skipping loop generation and latent saving.")
                if stream is not None: stream.output_queue.push(('warning', "Loop generation & latent saving skipped: Missing history latents."))
                final_latents = None # Mark as None if failed
            # --- End of first part of Code A block ---

        # --- Start of second part of Code A block (else for latent_input_file is None) ---
        else: # This 'else' corresponds to 'if latent_input_file is None:'
            print(f"Loading latent file: {latent_input_file}")
            if stream is not None: stream.output_queue.push(('progress', (None, '', make_progress_bar_html(0, f'Loading latent file {os.path.basename(latent_input_file)}...'))))
            try:
                 final_latents = torch.load(latent_input_file, map_location='cpu') # Load to CPU first
                 print(f"Loaded latent tensor with shape: {final_latents.shape}")
            except FileNotFoundError:
                  print(f"Error: Latent file not found at {latent_input_file}")
                  if stream is not None: stream.output_queue.push(('error', f"Latent file not found: {os.path.basename(latent_input_file)}"))
                  return # エラー時はworker終了
            except Exception as e:
                  print(f"Error loading latent file: {e}")
                  traceback.print_exc()
                  if stream is not None: stream.output_queue.push(('error', f"Error loading latent file: {e}"))
                  return # エラー時はworker終了

            filename = os.path.basename(latent_input_file)
            match = re.match(r"(\d+_\d+_\d+_\d+)_(\d+)_(\d+)_latent\.pt", filename)
            if match:
                 job_id = match.group(1)
                 total_generated_latent_frames_from_file = match.group(2) # Store separately
                 seed_from_file = match.group(3) # Store separately
                 print(f"Parsed from filename: job_id={job_id}, frames={total_generated_latent_frames_from_file}, seed={seed_from_file}")
                 # Use parsed seed if generation_count was 1, otherwise worker's random seed is already set
                 if generation_count == 1: # Assuming generation_count is accessible or passed
                     seed = int(seed_from_file) # Update the main seed variable
            else:
                 print("Warning: Could not parse info from latent filename. Using current job_id and seed.")
                 total_generated_latent_frames_from_file = 'unknown' # Indicate unknown count
                 # job_id and seed will be the ones generated at the start of this worker call

            # For loop decode, all_latent_section should be based on UI sliders
            # as these define the *intended* structure of the video from the latent
            total_latent_sections_from_ui = int(max(round((total_second_length * 30) / (latent_window_size * 4)), 1))
            connection_latent_sections_from_ui = int(max(round((connection_second_length * 30) / (latent_window_size * 4)), 1))
            all_latent_section = total_latent_sections_from_ui + connection_latent_sections_from_ui # Used for loop count (MAX)
            print(f"Using section lengths from UI for latent file decode loop iterations: all={all_latent_section}")
        # --- End of second part of Code A block ---

        # --- Start of third part of Code A block (loop decode) ---
        # Check if final_latents was successfully created or loaded
        if final_latents is not None:
            # --- デコードループ用の try...finally を追加 ---
            decode_loop_final_latents_gpu = None # For cleanup
            try:
                print(f"Starting loop decode with final_latents shape: {final_latents.shape}")
                if stream is not None: stream.output_queue.push(('progress', (None, '', make_progress_bar_html(0, 'Preparing for loop decode...'))))


                # VAEとLatentをGPUに移動 (tryブロック内)
                if not high_vram:
                    print("Loading VAE for loop decode (Low VRAM)...")
                    # Ensure transformer is offloaded if it exists
                    if transformer is not None: # Check if transformer was loaded
                        offload_model_from_device_for_memory_preservation(transformer, target_device=gpu, preserved_memory_gb=gpu_memory_preservation) # Use correct var
                    load_model_as_complete(vae, target_device=gpu)
                else:
                    print("Ensuring VAE is on GPU for loop decode (High VRAM)...")
                    if 'vae' in locals() and vae is not None: # Check if vae is loaded
                         vae.to(gpu)
                    else:
                         raise RuntimeError("VAE model not loaded for loop decoding (High VRAM).")


                print(f"Moving final_latents ({final_latents.shape}) to GPU for decoding...")
                decode_loop_final_latents_gpu = final_latents.to(gpu, dtype=vae.dtype) # Move to GPU, match VAE dtype


                final_history_pixels = None
                MAX = all_latent_section + 2 # Code A logic for loop iterations
                pixel_map = dict()
                print(f"Loop decode iterations (MAX): {MAX}")


                for i in range(MAX):
                    # --- 中断チェック (デコードループの先頭) ---
                    if stream is not None and stream.input_queue.top() == 'end':
                        print("Stop requested during decode loop.")
                        raise KeyboardInterrupt("Stop requested by user.") # メインハンドラでキャッチ
                    # --- ここまで ---

                    pixel_map_key = i % all_latent_section
                    latent_index = (all_latent_section - 1 - pixel_map_key) # Code A logic
                    latent_offset = latent_index * latent_window_size # Use latent_window_size from UI (should be 9)

                    percentage = int(100.0 * (i + 1) / MAX) # Progress update based on current iter
                    hint = f'Make 1 loop {i+1}/{MAX}' # Code A hint style
                    desc = f'Now making 1 loop decoding' # Code A description
                    if stream is not None: stream.output_queue.push(('progress', (None, desc, make_progress_bar_html(percentage, hint))))

                    section_latent_frames = latent_window_size * 2 # As used in Code A's else block
                    overlapped_frames = latent_window_size * 4 - 3 # Code A overlap value

                    decode_slice_start = latent_offset
                    # For initial decode, take LWS frames. For subsequent, take section_latent_frames.
                    decode_slice_end_initial = decode_slice_start + latent_window_size
                    decode_slice_end_section = decode_slice_start + section_latent_frames

                    # Boundary checks
                    if decode_slice_start >= decode_loop_final_latents_gpu.shape[2]:
                        print(f"Warning: Skipping iter {i}, latent_offset {latent_offset} is out of bounds for tensor of length {decode_loop_final_latents_gpu.shape[2]}.")
                        continue

                    if final_history_pixels is None:
                        # Initial decode
                        actual_slice_end = min(decode_slice_end_initial, decode_loop_final_latents_gpu.shape[2])
                        if decode_slice_start < actual_slice_end: # Ensure valid slice
                            print(f"Loop decode iter {i}: Initial decode slice [{decode_slice_start}:{actual_slice_end}]")
                            latents_to_decode = decode_loop_final_latents_gpu[:, :, decode_slice_start:actual_slice_end, :, :]
                            final_history_pixels = vae_decode(latents_to_decode, vae).cpu()
                            # pixel_map[latent_index] = final_history_pixels # Code A saves the first full decode too
                            print(f"Initial loop pixels decoded. Shape: {final_history_pixels.shape}")
                        else:
                             print(f"Warning: Skipping initial decode at iter {i}, invalid slice [{decode_slice_start}:{actual_slice_end}]")
                             continue # Skip to next iteration if slice is invalid
                    else:
                        # Subsequent decodes
                        current_pixels = pixel_map.get(latent_index)
                        if current_pixels is None:
                             actual_slice_end = min(decode_slice_end_section, decode_loop_final_latents_gpu.shape[2])
                             if decode_slice_start < actual_slice_end: # Ensure valid slice
                                print(f"Loop decode iter {i}: Decoding new section slice [{decode_slice_start}:{actual_slice_end}] for index {latent_index}")
                                latents_to_decode = decode_loop_final_latents_gpu[:, :, decode_slice_start:actual_slice_end]
                                current_pixels = vae_decode(latents_to_decode, vae).cpu()
                                pixel_map[latent_index] = current_pixels
                                print(f"Decoded section {latent_index}. Shape: {current_pixels.shape}")
                             else:
                                print(f"Warning: Skipping decode at iter {i}, invalid slice [{decode_slice_start}:{actual_slice_end}] for section {latent_index}.")
                                # current_pixels remains None

                        if current_pixels is not None:
                            print(f"Soft appending loop pixels: current (idx {latent_index}) shape {current_pixels.shape}, history shape {final_history_pixels.shape}, overlap {overlapped_frames}")
                            final_history_pixels = soft_append_bcthw(current_pixels, final_history_pixels, overlapped_frames)
                            print(f"Final history pixels shape after append: {final_history_pixels.shape}")
                        # else:
                        #    print(f"Skipping append for iter {i} as current_pixels for index {latent_index} is None (or slice was invalid).")


                # --- デコードループ後の後処理 ---
                if final_history_pixels is not None and final_history_pixels.shape[2] > 0:
                    start_trim = latent_window_size * 4
                    end_trim_offset = latent_window_size * 4 - 3 # This is the number of frames to remove from the end
                    print(f"Final pixel history shape before Code A slicing: {final_history_pixels.shape}")
                    print(f"Applying Code A slicing: start_trim={start_trim}, end_trim_amount_from_end={end_trim_offset}")

                    history_length = final_history_pixels.shape[2]
                    if history_length > start_trim:
                        # Apply start trim
                        final_history_pixels_sliced_at_start = final_history_pixels[:, :, start_trim:, :, :]
                        print(f"Shape after start trim: {final_history_pixels_sliced_at_start.shape}")

                        # Apply end trim (remove 'end_trim_offset' frames from the end)
                        current_length_after_start_trim = final_history_pixels_sliced_at_start.shape[2]
                        if current_length_after_start_trim > end_trim_offset :
                            # Keep frames from start up to (length - end_trim_offset)
                            final_history_pixels_1loop = final_history_pixels_sliced_at_start[:, :, :current_length_after_start_trim - end_trim_offset, :, :]
                            print(f"Shape after end trim (final 1-loop): {final_history_pixels_1loop.shape}")
                        else:
                            print(f"Warning: Length after start trim ({current_length_after_start_trim}) <= end_trim_offset ({end_trim_offset}). Slicing results in empty or invalid tensor.")
                            final_history_pixels_1loop = torch.empty_like(final_history_pixels_sliced_at_start[:, :, :0]) # Empty tensor
                    else:
                         print(f"Warning: History length ({history_length}) <= start_trim ({start_trim}). Cannot perform Code A slicing.")
                         final_history_pixels_1loop = torch.empty_like(final_history_pixels[:, :, :0]) # Empty tensor


                    if final_history_pixels_1loop.shape[2] > 0:
                        # Determine frame count string for filename
                        frame_count_str_for_filename = total_generated_latent_frames_from_file if latent_input_file else total_generated_latent_frames

                        # Save 1-loop video (Code A saves this as jobid_frames_seed_1loop.mp4)
                        output_filename_1loop = os.path.join(outputs_folder, f'{job_id}_{frame_count_str_for_filename}_{seed}_1loop.mp4') # Match Code A naming
                        save_bcthw_as_mp4(final_history_pixels_1loop, output_filename_1loop, fps=30, crf=mp4_crf)
                        if stream is not None: stream.output_queue.push(('message', f"1-Loop video saved (Code A slice): {os.path.basename(output_filename_1loop)}"))


                        final_history_pixels_repeated = final_history_pixels_1loop.repeat(1, 1, loop_num, 1, 1)
                        output_filename_looped = os.path.join(outputs_folder, f'{job_id}_{frame_count_str_for_filename}_{seed}_loop_{loop_num}.mp4')
                        print(f"Saving {loop_num}-Loop video to: {output_filename_looped}")
                        save_bcthw_as_mp4(final_history_pixels_repeated, output_filename_looped, fps=30, crf=mp4_crf)
                        if stream is not None: stream.output_queue.push(('file', output_filename_looped)) # This is the main output
                    else:
                        print("Skipping video saving as Code A slicing resulted in 0 frames for 1-loop video.")
                        if stream is not None: stream.output_queue.push(('message', "Loop video saving skipped: 0 frames after slicing for 1-loop."))


                else:
                     print("Loop decoding process did not produce final pixel history. Skipping saving.")
                     if stream is not None: stream.output_queue.push(('message', "Loop video saving skipped: Decoding failed or no pixels generated."))


            finally:
                # --- デコードループのクリーンアップ ---
                print("Decode loop finally block executing...")
                if decode_loop_final_latents_gpu is not None:
                    # 元の変数に戻す (GPUメモリ解放のため)
                    final_latents = decode_loop_final_latents_gpu.cpu()
                    decode_loop_final_latents_gpu = None # GPU参照をクリア
                    print("Moved final_latents back to CPU from GPU tensor.")
                if not high_vram:
                    if 'vae' in locals() and vae is not None: # Check if vae was loaded
                        unload_complete_models(vae) # VAEをアンロード
                        print("Unloaded VAE (low VRAM) in decode finally.")
                # Clear other large tensors from GPU if they exist
                pixel_map.clear()
                final_history_pixels = None
                final_history_pixels_1loop = None
                gc.collect()
                torch.cuda.empty_cache()
                print("Decode loop finally block finished.")
            # --- ここまでデコードループ用 try...finally ---

        else: # Handle case where final_latents is None (e.g. latent gen failed or load failed)
             print("Final latents tensor is None. Cannot proceed with loop decoding.")
             if stream is not None: stream.output_queue.push(('message', "Loop video generation skipped: Latent tensor not available."))
        # --- End of third part of Code A block ---

    # --- KeyboardInterrupt をキャッチ ---
    except KeyboardInterrupt:
        print("--- KeyboardInterrupt caught by main worker handler ---")
        if stream is not None: stream.output_queue.push(('message', "Generation stopped by user."))
        # finallyブロックでクリーンアップが行われる
        # loop_workerが'end'を送るのでここでは何もしない
        pass # finallyブロックへ進む

    # --- その他の例外をキャッチ ---
    except Exception as e:
        print(f"--- Error during worker execution ({type(e).__name__}) ---")
        traceback.print_exc()
        if stream is not None:
            stream.output_queue.push(('error', f"An error occurred in worker: {e}"))
        # finallyブロックへ進む

    # --- 常に実行される finally ブロック ---
    finally:
        print("Worker finally block executing...")
        if not high_vram:
            # 全ての可能性のあるモデルをアンロード試行
            models_to_unload = []
            if 'text_encoder' in locals() and text_encoder is not None: models_to_unload.append(text_encoder)
            if 'text_encoder_2' in locals() and text_encoder_2 is not None: models_to_unload.append(text_encoder_2)
            if 'image_encoder' in locals() and image_encoder is not None: models_to_unload.append(image_encoder)
            if 'vae' in locals() and vae is not None: models_to_unload.append(vae)
            if 'transformer' in locals() and transformer is not None: models_to_unload.append(transformer) # Transformerも対象に
            if models_to_unload:
                unload_complete_models(*models_to_unload)
                print(f"Unloaded models in finally (low VRAM): {[m.__class__.__name__ for m in models_to_unload]}")
            else:
                print("No models to unload in finally (low VRAM).")
        else:
            print("Cleanup (High VRAM) - Models presumed to be on GPU. Cleaning CUDA cache.")


        # GPUに残っている可能性のあるテンソルをCPUへ移動 (念のため)
        # And clear references to large tensors
        tensors_to_clear = [
            'final_latents', 'start_latent', 'history_latents', 'post_history_latents',
            'history_pixels', 'post_history_pixels', 'real_history_latents',
            'post_real_history_latents', 'final_history_pixels_1loop',
            'llama_vec', 'llama_vec_n', 'clip_l_pooler', 'clip_l_pooler_n',
            'image_encoder_last_hidden_state', 'input_image_pt', 'generated_latents'
        ]
        for tensor_name in tensors_to_clear:
            if tensor_name in locals():
                tensor_var = locals()[tensor_name]
                if isinstance(tensor_var, torch.Tensor) and tensor_var.is_cuda:
                    print(f"Moving {tensor_name} to CPU in outer finally.")
                    locals()[tensor_name] = tensor_var.cpu()
                # Clear reference
                locals()[tensor_name] = None
            elif tensor_name in globals(): # Should not happen for most of these
                tensor_var = globals()[tensor_name]
                if isinstance(tensor_var, torch.Tensor) and tensor_var.is_cuda:
                    print(f"Moving global {tensor_name} to CPU in outer finally.")
                    globals()[tensor_name] = tensor_var.cpu()
                globals()[tensor_name] = None


        gc.collect()
        torch.cuda.empty_cache()
        print("Worker finally block finished.")
    # --- ここまで finally ブロック ---

    # --- worker関数の終了 ---
    print("Worker function returning.")
    return # 明示的に返す
# --- ここまで worker 関数 ---

# <<< --- START: PRESET HANDLING FUNCTIONS --- >>>
# Name of the JSON file that load_presets() reads and save_presets() writes.
# It is a relative path, so it resolves against the current working directory.
PRESET_FILE = "framepack_presets.json" # <<< CHANGED to common filename

def load_presets(): # <<< REMOVED _kohya
    """Read the preset store from PRESET_FILE.

    Returns:
        dict: The parsed top-level JSON object. An empty dict is returned
        when the file does not exist, cannot be read or decoded, or when its
        top-level JSON value is not an object.
    """
    if os.path.exists(PRESET_FILE):
        try:
            with open(PRESET_FILE, 'r', encoding='utf-8') as preset_fp:
                loaded = json.load(preset_fp)
        except json.JSONDecodeError:
            # The file exists but does not contain valid JSON.
            print(f"Error: Could not decode JSON from '{PRESET_FILE}'. Please check the file format.")
            return {}
        except Exception as e:
            # Any other failure (I/O, decoding, ...) is reported, not raised.
            print(f"Error loading presets from '{PRESET_FILE}': {e}")
            return {}
        else:
            if isinstance(loaded, dict):
                return loaded
            print(f"Warning: Preset file '{PRESET_FILE}' is not a valid JSON dictionary. Ignoring.")
            return {}
    return {}

def save_presets(presets): # <<< REMOVED _kohya
    """Write the presets dictionary to PRESET_FILE as UTF-8 JSON.

    The data is serialized in memory first and written to a temporary file
    next to PRESET_FILE, which is then moved over the target. A failure while
    serializing or writing therefore leaves the existing preset file intact
    instead of truncating it.

    Args:
        presets: Mapping of preset name to settings. Must be JSON-serializable.

    Returns:
        None. Failures are not raised; they are printed and reported through
        gr.Warning.
    """
    tmp_path = f"{PRESET_FILE}.tmp"
    try:
        # Serialize before touching the disk: dumping straight into a file
        # opened with 'w' would wipe every stored preset if encoding failed.
        payload = json.dumps(presets, indent=4, ensure_ascii=False)
        with open(tmp_path, 'w', encoding='utf-8') as f:
            f.write(payload)
        # Swap the fully written file into place in a single step.
        os.replace(tmp_path, PRESET_FILE)
    except Exception as e:
        # Best-effort removal of a partially written temporary file; it may
        # not exist if the failure happened before it was created.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        print(f"Error saving presets to '{PRESET_FILE}': {e}")
        gr.Warning(f"Failed to save presets: {e}")
        return
    print(f"Presets saved to {PRESET_FILE}")

def get_preset_names(presets): # <<< REMOVED _kohya
    """Return the keys of the presets mapping as a new, ascending-sorted list."""
    names = list(presets.keys())
    names.sort()
    return names

def gather_current_settings(*args): # <<< REMOVED _kohya
    """Map positional UI values onto their preset setting names.

    Args:
        *args: Component values; the first 25 are used and any further
            values are ignored.

    Returns:
        dict: Setting name -> value, in the key order listed below.

    Raises:
        IndexError: If fewer than 25 values are supplied.
    """
    # Order must match preset_components_inputs
    setting_keys = (
        "prompt",
        "n_prompt",
        "generation_count",
        "seed",
        "total_second_length",
        "connection_second_length",
        "padding_second_length",
        "loop_num",
        "steps",
        "gs",
        "gpu_memory_preservation",
        "use_teacache",
        "mp4_crf",
        "progress_preview_option",
        "lora_file_1",
        "lora_multiplier_1",
        "lora_file_2",
        "lora_multiplier_2",
        "lora_file_3",
        "lora_multiplier_3",
        "lora_file_4",
        "lora_multiplier_4",
        "lora_file_5",
        "lora_multiplier_5",
        "fp8_optimization",
    )
    # Index explicitly (rather than zip) so a short argument list still fails.
    return {key: args[position] for position, key in enumerate(setting_keys)}

def apply_preset(preset_name): # <<< REMOVED _kohya
    """Load a preset and return one gr.update() per preset-controlled component.

    Args:
        preset_name: Key of the preset inside the preset file.

    Returns:
        list: 25 gr.update objects, one per entry of the table below (the
        original comment says this order follows preset_components_outputs).
        When the preset is missing, empty or malformed, every entry is a
        no-op gr.update() and a Gradio warning is shown. Settings absent from
        the preset fall back to the defaults in the table.
    """
    # (setting key, fallback value). The number of returned updates is derived
    # from this table so it cannot drift from a separately hard-coded count.
    component_defaults = (
        ("prompt", ""),
        ("n_prompt", ""),
        ("generation_count", 1),
        ("seed", 31337),
        ("total_second_length", 1),
        ("connection_second_length", 1),
        ("padding_second_length", 0),
        ("loop_num", 5),
        ("steps", 25),
        ("gs", 10.0),
        ("gpu_memory_preservation", 8),
        ("use_teacache", True),
        ("mp4_crf", 16),
        ("progress_preview_option", "Reduce Progress File Output"),
        ("lora_file_1", None),
        ("lora_multiplier_1", 0.8),
        ("lora_file_2", None),
        ("lora_multiplier_2", 0.0),
        ("lora_file_3", None),
        ("lora_multiplier_3", 0.0),
        ("lora_file_4", None),
        ("lora_multiplier_4", 0.0),
        ("lora_file_5", None),
        ("lora_multiplier_5", 0.0),
        ("fp8_optimization", False),
    )

    presets = load_presets()
    preset_data = presets.get(preset_name)

    if not preset_data:
        gr.Warning(f"Preset '{preset_name}' not found.")
        return [gr.update() for _ in component_defaults]

    if not isinstance(preset_data, dict):
        # load_presets() only validates the top-level object, so an entry in a
        # hand-edited file may be a string or list; .get() would crash on it.
        gr.Warning(f"Preset '{preset_name}' is malformed and cannot be applied.")
        return [gr.update() for _ in component_defaults]

    print(f"Applying preset: {preset_name}")

    updates = []
    for key, default in component_defaults:
        value = preset_data.get(key, default)
        if key.startswith("lora_file_") and value and isinstance(value, str) and not os.path.isfile(value):
            # NOTE(review): presumably the gr.File component cannot show a
            # path that no longer exists - confirm. Clear the field instead of
            # handing it a stale path from an earlier session.
            print(f"Preset '{preset_name}': file for '{key}' not found: {value}")
            gr.Warning(f"Preset '{preset_name}': file for '{key}' not found, field cleared.")
            value = None
        updates.append(gr.update(value=value))
    return updates


def save_preset_action(preset_name, *args): # <<< REMOVED _kohya
    """Store the current UI values under preset_name and refresh the dropdown.

    Args:
        preset_name: Name to save (or overwrite) the preset under.
        *args: Component values, forwarded to gather_current_settings().

    Returns:
        tuple: Two gr.update objects. The first carries the new choices with
        the saved name selected (or is a no-op when no name was given); the
        second sets its component's value to "".
    """
    if preset_name:
        all_presets = load_presets()
        all_presets[preset_name] = gather_current_settings(*args)
        save_presets(all_presets)

        gr.Info(f"Preset '{preset_name}' saved.")
        dropdown_update = gr.update(choices=get_preset_names(all_presets), value=preset_name)
        return dropdown_update, gr.update(value="")

    # An empty name cannot be used as a preset key.
    gr.Warning("Please enter a name for the preset.")
    return gr.update(), gr.update(value="")


def delete_preset_action(preset_name): # <<< REMOVED _kohya
    """Remove preset_name from the preset file and refresh the dropdown.

    Args:
        preset_name: Name of the preset to delete.

    Returns:
        tuple: Two gr.update objects. After a deletion the first carries the
        remaining choices with the selection reset to None; otherwise it is a
        no-op. The second always sets its component's value to "".
    """
    if not preset_name:
        gr.Warning("Please select a preset to delete.")
        return gr.update(), gr.update(value="")

    stored = load_presets()
    if preset_name not in stored:
        gr.Warning(f"Preset '{preset_name}' not found for deletion.")
        return gr.update(), gr.update(value="")

    stored.pop(preset_name)
    save_presets(stored)
    gr.Info(f"Preset '{preset_name}' deleted.")
    remaining_names = get_preset_names(stored)
    return gr.update(choices=remaining_names, value=None), gr.update(value="")

def load_preset_ui_action(preset_name): # <<< REMOVED _kohya
    """Apply a preset and also echo its name back.

    Returns:
        list: A gr.update carrying preset_name as value, followed by the
        updates produced by apply_preset(preset_name).
    """
    setting_updates = apply_preset(preset_name)
    return [gr.update(value=preset_name), *setting_updates]

def refresh_preset_dropdown(): # <<< REMOVED _kohya
    """Re-read the preset file and return a gr.update with the sorted names as choices."""
    available_names = get_preset_names(load_presets())
    return gr.update(choices=available_names)

# <<< --- END: PRESET HANDLING FUNCTIONS --- >>>


# --- process関数の引数と呼び出しを更新 (No changes required for presets) ---
def process(input_image, prompt, n_prompt, generation_count, seed, total_second_length, connection_second_length, padding_second_length, loop_num, latent_window_size, steps, cfg, gs, rs, gpu_memory_preservation, use_teacache, mp4_crf, progress_preview_option, latent_input_file, lora_file_1, lora_multiplier_1, lora_file_2, lora_multiplier_2, lora_file_3, lora_multiplier_3, lora_file_4, lora_multiplier_4, lora_file_5, lora_multiplier_5, fp8_optimization):
    """Validate the inputs, start loop_worker asynchronously and relay its messages.

    This is a generator. Every yielded value is a 7-tuple of component values
    or gr.update objects; the order appears to match the `outputs` list of the
    UI section (result video, preview image, progress description, progress
    bar, start button, end button, generation counter).

    Args:
        input_image: Source image; required unless the mode is
            'Decode Latent File', in which case it is replaced by None.
        progress_preview_option: Mode label; translated into the
            reduce_file_output / without_preview / output_latent_image flags
            and into whether latent_input_file is used. Unknown labels fall
            back to "Reduce Progress File Output".
        latent_input_file: Path of a latent file; required and checked for
            existence only in 'Decode Latent File' mode.
        lora_file_1..5, lora_multiplier_1..5: Packed into two lists (empty
            file entries become None) before being handed to the worker.
        (all other parameters): Forwarded unchanged to loop_worker.

    Yields:
        tuple: UI updates for each message read from the worker's output
        queue. The handled flags are 'file', 'progress', 'generation count',
        'message', 'warning', 'error' and 'end'; 'end' re-enables the start
        button and finishes the generator.
    """
    global stream # module-level stream, also read by end_process
    options = {
        "All Progress File Output": 1, "Reduce Progress File Output": 2, "Without Preview": 3,
        "Without VAE Decode": 4, "Decode Latent File": 5
    }
    option_val = options.get(progress_preview_option, 2) # Default to Reduce

    # Determine flags based on option_val (as in the original code logic)
    reduce_file_output = (option_val >= 2) # If 2, 3, 4, 5
    without_preview = (option_val >= 3)    # If 3, 4, 5
    output_latent_image = (option_val == 4)
    use_latent_input = (option_val == 5)

    print(f"Processing with Mode: {progress_preview_option} (Value: {option_val})")
    print(f"  Reduce File Output (overwrite intermediate): {reduce_file_output}")
    print(f"  Without Preview (no intermediate decode/save): {without_preview}")
    print(f"  Output Latent Only (.pt/.png): {output_latent_image}")
    print(f"  Use Latent Input File: {use_latent_input}")


    actual_latent_input_filepath = None # Initialize

    # --- Input validation ---
    if use_latent_input:
        if latent_input_file is None: # latent_input_file is already a filepath string
            gr.Warning("Latent file is required for 'Decode Latent File' mode.")
            yield None, None, "Error: Latent file is required.", "", gr.update(interactive=True), gr.update(interactive=False), ""
            return
        actual_latent_input_filepath = latent_input_file # It's already a path
        if not os.path.exists(actual_latent_input_filepath):
            gr.Warning(f"Latent file not found: {actual_latent_input_filepath}")
            yield None, None, f"Error: Latent file not found: {os.path.basename(actual_latent_input_filepath)}", "", gr.update(interactive=True), gr.update(interactive=False), ""
            return
        print(f"Using latent input file: {actual_latent_input_filepath}")
        input_image = None # Ensure input_image is None if using latent file
    else: # Not using latent input
        if input_image is None:
            gr.Warning("Input image is required unless using 'Decode Latent File' mode.")
            yield None, None, "Error: Input image is required.", "", gr.update(interactive=True), gr.update(interactive=False), ""
            return
        actual_latent_input_filepath = None # Ensure it's None if not used
    # --- End of input validation ---

    # Prepare LoRA arguments (already filepaths from gr.File type="filepath")
    lora_files = [f if f else None for f in [lora_file_1, lora_file_2, lora_file_3, lora_file_4, lora_file_5]]
    lora_multipliers = [lora_multiplier_1, lora_multiplier_2, lora_multiplier_3, lora_multiplier_4, lora_multiplier_5]

    print("--- Starting Generation ---")
    if not use_latent_input:
        print(f"Prompt: {prompt}")
        print(f"Negative Prompt: {n_prompt}")
        print(f"Input Image Shape: {input_image.shape if input_image is not None else 'N/A'}")
        print(f"LoRA Files: {[os.path.basename(f) if f else None for f in lora_files]}")
        print(f"LoRA Multipliers: {lora_multipliers}")
        print(f"FP8 Optimization: {fp8_optimization}")
    print(f"Generation Count: {generation_count}")
    print(f"Seed: {seed} (Used if Generation Count is 1, otherwise random per iteration)")
    print(f"Main Sections: {total_second_length}, Connection Sections: {connection_second_length}") # UI labels are "Sections"
    print(f"Padding Ratio: {padding_second_length}, Loop Count: {loop_num}")
    print(f"Steps: {steps}, Distilled CFG: {gs}, TeaCache: {use_teacache}, CRF: {mp4_crf}")


    # --- UI update: generation is starting (disable Start, enable End) ---
    yield None, None, '', '', gr.update(interactive=False), gr.update(interactive=True), ''

    # --- Create a fresh stream for this run ---
    stream = AsyncStream()
    print("Starting worker thread...")

    # --- Run loop_worker asynchronously ---
    async_run(loop_worker, input_image, prompt, n_prompt, generation_count, seed,
              total_second_length, connection_second_length, padding_second_length, loop_num,
              latent_window_size, steps, cfg, gs, rs, gpu_memory_preservation, use_teacache, mp4_crf,
              reduce_file_output, without_preview, output_latent_image, actual_latent_input_filepath, # Pass processed filepath
              lora_files, lora_multipliers, fp8_optimization)

    output_filename = None
    final_message = "Generation finished." # Default message shown when the run ends

    # --- Wait for and relay worker messages ---
    while True:
        try:
            # Fetch the next (flag, data) message pushed by the worker.
            flag, data = stream.output_queue.next()

            if flag == 'file':
                output_filename = data
                print(f"Received file: {output_filename}")
                yield output_filename, gr.update(), gr.update(), gr.update(), gr.update(interactive=False), gr.update(interactive=True), gr.update()
            elif flag == 'progress':
                preview, desc, html = data
                preview_update = gr.update(visible=preview is not None, value=preview)
                yield gr.update(value=output_filename), preview_update, desc, html, gr.update(interactive=False), gr.update(interactive=True), gr.update()
            elif flag == 'generation count':
                generation_count_text = data # Already formatted string with seed
                yield gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), generation_count_text
            elif flag == 'message':
                print(f"Message: {data}")
                yield gr.update(), gr.update(), data, gr.update(), gr.update(), gr.update(), gr.update()
            elif flag == 'warning':
                print(f"Warning: {data}")
                gr.Warning(data) # Display Gradio warning
                yield gr.update(), gr.update(), f"Warning: {data}", gr.update(), gr.update(), gr.update(), gr.update() # Update desc
            elif flag == 'error':
                error_message = data
                print(f"Error received from worker: {error_message}")
                # gr.Error is an exception class: merely constructing it shows
                # nothing, and raising it would abort this generator before the
                # 'end' signal restores the buttons. Show a toast the same way
                # the 'warning' branch does instead.
                gr.Warning(f"Generation Error: {error_message}")
                final_message = f"Error: {error_message}" # Set for final display
                # Do not break here; wait for 'end' signal to correctly update buttons
                yield output_filename, gr.update(visible=False), final_message, "", gr.update(), gr.update(), "" # Clear progress
            elif flag == 'end':
                print(f"Received 'end' signal. Final message: {final_message}")
                yield output_filename, gr.update(visible=False, value=None), final_message, '', gr.update(interactive=True), gr.update(interactive=False), ''
                stream = None # Clear the stream object
                break # Exit the loop
            elif flag is None and data is None: # Should not happen with default next()
                print("Warning: Received None flag/data from stream queue (timeout or unexpected).")
                # Potentially update UI to indicate possible stall or allow retry
                yield gr.update(value=output_filename), gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), gr.update()
            else:
                # Unknown flags are logged only; nothing is yielded for them.
                print(f"Warning: Received unknown flag '{flag}' from worker.")


        except StopIteration: # Should not happen if worker always sends 'end'
            print("Stream queue StopIteration (worker might have finished unexpectedly without 'end' signal).")
            final_message = final_message if final_message != "Generation finished." else "Worker finished unexpectedly."
            yield output_filename, gr.update(visible=False), final_message, "", gr.update(interactive=True), gr.update(interactive=False), ""
            stream = None
            break
        except Exception as e: # Catch other errors in the UI loop
            print(f"Error in process loop (UI): {e}")
            traceback.print_exc()
            final_message = f"UI Error: {e}"
            yield output_filename, gr.update(visible=False), final_message, "", gr.update(interactive=True), gr.update(interactive=False), ""
            stream = None
            break
# --- ここまで ---

def end_process():
    """Ask the running worker to stop by pushing 'end' onto the stream's input queue.

    Nothing is pushed when no stream / input queue is available or when
    'end' is already at the top of the queue. Errors raised while talking to
    the queue are printed, not propagated.
    """
    global stream # refers to the module-level stream
    print("End button clicked. Sending 'end' signal to worker.")

    queue_available = (
        stream is not None
        and hasattr(stream, 'input_queue')
        and stream.input_queue is not None
    )
    if not queue_available:
        print("Stream or input queue not initialized. Cannot send signal.")
        return

    try:
        already_signalled = stream.input_queue.top() == 'end'
        if already_signalled:
            # Avoid pushing multiple 'end' signals
            print("'end' signal already sent.")
        else:
            stream.input_queue.push('end')
            print("Sent 'end' signal.")
    except Exception as e: # e.g. the queue was closed in the meantime
        print(f"Error sending 'end' signal: {e}")

# Sample prompts for the "Quick List" dataset; each prompt is wrapped in its
# own single-element list.
quick_prompts = [
    [prompt_text]
    for prompt_text in (
        'The girl dances gracefully, with clear movements, full of charm.',
        'A character doing some simple body movements.',
    )
]

# Create the Gradio Blocks app with the CSS returned by make_progress_bar_css()
# and call .queue() on it; the UI components are declared in the `with block:`
# section that follows.
css = make_progress_bar_css()
block = gr.Blocks(css=css).queue()
with block:
    gr.Markdown('# FramePackLoop with LoRA & FP8 & Presets') # <<< Title Updated
    with gr.Row():
        with gr.Column():
            # <<< --- START: PRESET UI --- >>>
            with gr.Group():
                gr.Markdown("### Preset Management")
                with gr.Row():
                    preset_dropdown = gr.Dropdown(label="Select Preset", choices=get_preset_names(load_presets())) # <<< REMOVED _kohya
                    refresh_presets_button = gr.Button("🔄 Refresh") # <<< REMOVED _kohya
                with gr.Row():
                    preset_name_textbox = gr.Textbox(label="Preset Name (for saving/overwriting)") # <<< REMOVED _kohya
                with gr.Row():
                    load_preset_button = gr.Button("Load Selected Preset") # <<< REMOVED _kohya
                    save_preset_button = gr.Button("Save/Overwrite Preset") # <<< REMOVED _kohya
                    delete_preset_button = gr.Button("Delete Selected Preset") # <<< REMOVED _kohya
            # <<< --- END: PRESET UI --- >>>

            input_image = gr.Image(sources='upload', type="numpy", label="Image (Ignored if Latent File used, NOT saved in presets)", height=320)
            latent_input_file = gr.File(label="Latent File Input (.pt, NOT saved in presets)", file_types=[".pt"], type="filepath")

            prompt = gr.Textbox(label="Prompt", value='')
            example_quick_prompts = gr.Dataset(samples=quick_prompts, label='Quick List', samples_per_page=1000, components=[prompt])
            example_quick_prompts.click(lambda x: x[0], inputs=[example_quick_prompts], outputs=prompt, show_progress=False, queue=True)

            with gr.Row():
                start_button = gr.Button(value="Start Generation")
                end_button = gr.Button(value="End Generation", interactive=False)

            with gr.Group():
                use_teacache = gr.Checkbox(label='Use TeaCache', value=True, info='Faster speed, but often makes hands and fingers slightly worse.')
                n_prompt = gr.Textbox(label="Negative Prompt", value="", visible=True) # Visible for preset
                generation_count = gr.Slider(label="Generation Count", minimum=1, maximum=500, value=1, step=1, info='If > 1, Seed will be random per iteration.')
                progress_preview_option = gr.Radio(
                    choices=["All Progress File Output", "Reduce Progress File Output", "Without Preview", "Without VAE Decode", "Decode Latent File"],
                    label="Progress/Mode Option",
                    info="Controls preview, file output, and operation mode.",
                    value="Reduce Progress File Output"
                )
                gr.Markdown("""
                    - **All Progress File Output**: Saves all intermediate files.
                    - **Reduce Progress File Output**: Overwrites intermediate files (system_preview.mp4).
                    - **Without Preview**: No preview, faster final output. Saves only the final loop video.
                    - **Without VAE Decode**: Outputs latent tensor file (.pt) and optionally a decoded latent image (.png). Skips final loop decoding. **Requires LoRA/FP8 config during generation if used later.**
                    - **Decode Latent File**: Generates video from the provided Latent File Input (.pt). **Requires matching section lengths below.** Input image/prompt/LoRA/FP8 are ignored.
                    """,)

                seed = gr.Number(label="Seed", value=31337, precision=0, info='Used only when Generation Count is 1 (or for the first iteration if >1 and not random).')
                total_second_length = gr.Slider(label="Main Video Sections (Generation)", minimum=1, maximum=120, value=1, step=1, info="Number of main generation sections (~0.3s each based on 9-frame window). Ignored for 'Decode Latent File'.")
                connection_second_length = gr.Slider(label="Connection Video Sections (Generation/Decode)", minimum=1, maximum=120, value=1, step=1, info="Number of connection sections (~0.3s each). **Must match the latent file's original connection sections if using 'Decode Latent File' mode.**")
                latent_window_size = gr.Slider(label="Latent Window Size", minimum=9, maximum=9, value=9, step=1, visible=False) # Not in preset

                def update_text(slider_value, latent_window_size_val):
                    """Return int(slider_value * latent_window_size_val) as text, or "Invalid input" for non-numeric arguments."""
                    numeric_types = (int, float)
                    if not isinstance(slider_value, numeric_types) or not isinstance(latent_window_size_val, numeric_types):
                        return "Invalid input"
                    frame_count = int(slider_value * latent_window_size_val)
                    return str(frame_count)

                with gr.Group():
                    padding_second_length = gr.Slider(label="Padding Section Ratio (Generation)", minimum=0, maximum=10, value=0, step=0.1, info="Initial padding = ratio * Latent Window Size frames. 0 uses original image, >0 moves away. Ignored for 'Decode Latent File'.")
                    padding_frames_text = gr.Textbox(label="Padding Frames (Approx)", interactive=False, info="Effective padding frames.")
                    padding_second_length.change(fn=update_text, inputs=[padding_second_length, latent_window_size], outputs=padding_frames_text, queue=False)
                    # block.load moved to end

                loop_num = gr.Slider(label="Loop Count (Final Video)", minimum=1, maximum=100, value=5, step=1, info='Number of times the generated loop is repeated in the final output.')
                steps = gr.Slider(label="Steps (Generation)", minimum=1, maximum=100, value=25, step=1, info='Changing this value is not recommended. Ignored for "Decode Latent File".')
                cfg = gr.Slider(label="CFG Scale (Generation)", minimum=1.0, maximum=32.0, value=1.0, step=0.01, visible=False) # Not in preset
                gs = gr.Slider(label="Distilled CFG Scale (Generation)", minimum=1.0, maximum=32.0, value=10.0, step=0.01, info='Changing this value is not recommended. Ignored for "Decode Latent File".')
                rs = gr.Slider(label="CFG Re-Scale (Generation)", minimum=0.0, maximum=1.0, value=0.0, step=0.01, visible=False) # Not in preset
                gpu_memory_preservation = gr.Slider(label="GPU Inference Preserved Memory (GB)", minimum=4, maximum=128, value=8, step=0.1, info="Memory reserved for other models (like VAE). Larger = slower inference but less OOM risk.")
                mp4_crf = gr.Slider(label="MP4 CRF (Quality)", minimum=0, maximum=51, value=16, step=1, info="Lower means better quality (0=lossless, ~16-28=good). Affects final MP4 output.")

            with gr.Group(visible=lora_fp8_available) as lora_group: # <<< REMOVED _kohya
                gr.Markdown("### LoRA Settings (Generation Only)")
                with gr.Row():
                    lora_file_1 = gr.File(label="LoRA 1 File", file_count="single", file_types=[".safetensors", ".ckpt", ".pt"], type="filepath")
                    lora_multiplier_1 = gr.Slider(label="LoRA 1 Mult", minimum=-1.0, maximum=2.0, value=0.8, step=0.05)
                with gr.Row():
                    lora_file_2 = gr.File(label="LoRA 2 File", file_count="single", file_types=[".safetensors", ".ckpt", ".pt"], type="filepath")
                    lora_multiplier_2 = gr.Slider(label="LoRA 2 Mult", minimum=-1.0, maximum=2.0, value=0.0, step=0.05)
                with gr.Row():
                    lora_file_3 = gr.File(label="LoRA 3 File", file_count="single", file_types=[".safetensors", ".ckpt", ".pt"], type="filepath")
                    lora_multiplier_3 = gr.Slider(label="LoRA 3 Mult", minimum=-1.0, maximum=2.0, value=0.0, step=0.05)
                with gr.Row():
                    lora_file_4 = gr.File(label="LoRA 4 File", file_count="single", file_types=[".safetensors", ".ckpt", ".pt"], type="filepath")
                    lora_multiplier_4 = gr.Slider(label="LoRA 4 Mult", minimum=-1.0, maximum=2.0, value=0.0, step=0.05)
                with gr.Row():
                    lora_file_5 = gr.File(label="LoRA 5 File", file_count="single", file_types=[".safetensors", ".ckpt", ".pt"], type="filepath")
                    lora_multiplier_5 = gr.Slider(label="LoRA 5 Mult", minimum=-1.0, maximum=2.0, value=0.0, step=0.05)

            with gr.Group(visible=lora_fp8_available) as opt_group: # <<< REMOVED _kohya
                 gr.Markdown("### Optimization (Generation Only)")
                 fp8_optimization = gr.Checkbox(label="FP8 Optimization (Experimental)", value=False, info="Reduces memory usage and potentially speeds up inference, requires Ampere+ GPU. May affect quality. Ignored for 'Decode Latent File'.")

            # Fallback when LoRA/FP8 is not available: rebind every LoRA/FP8
            # variable to a fresh hidden component created inside the existing
            # groups, so the names used by the input lists below still resolve.
            if not lora_fp8_available:
                with lora_group:
                    # Five hidden (file, multiplier) pairs. Each tuple is built
                    # left to right, so components are still created in the
                    # order file_1, multiplier_1, file_2, multiplier_2, ...
                    (
                        (lora_file_1, lora_multiplier_1),
                        (lora_file_2, lora_multiplier_2),
                        (lora_file_3, lora_multiplier_3),
                        (lora_file_4, lora_multiplier_4),
                        (lora_file_5, lora_multiplier_5),
                    ) = [
                        (
                            gr.File(type="filepath", visible=False, value=None),
                            gr.Slider(visible=False, value=0.0),
                        )
                        for _ in range(5)
                    ]
                with opt_group:
                    fp8_optimization = gr.Checkbox(value=False, visible=False)


        # Output column: sampling preview (hidden initially), the result video,
        # an explanatory note and three progress widgets.
        with gr.Column():
            preview_image = gr.Image(
                label="Sampling Preview",
                height=200,
                visible=False,
            )
            result_video = gr.Video(
                label="Result Video",
                autoplay=True,
                show_share_button=False,
                height=512,
                loop=True,
            )
            gr.Markdown(
                'Note: Due to inverted sampling, ending actions might appear first. The start will be generated later.'
            )
            # The progress widgets all start empty and share one CSS class.
            progress_desc = gr.Markdown("", elem_classes="no-generating-animation")
            progress_bar = gr.HTML("", elem_classes="no-generating-animation")
            progress_gcounter = gr.Markdown("", elem_classes="no-generating-animation")

    # Inputs wired to `process` by the Start button. Gradio hands component
    # values to the callback positionally, so the order here is significant.
    ips = [
        input_image,
        prompt,
        n_prompt,
        generation_count,
        seed,
        total_second_length,
        connection_second_length,
        padding_second_length,
        loop_num,
        latent_window_size,
        steps,
        cfg,
        gs,
        rs,
        gpu_memory_preservation,
        use_teacache,
        mp4_crf,
        progress_preview_option,
        latent_input_file,
        # LoRA slots: each file is immediately followed by its multiplier.
        lora_file_1,
        lora_multiplier_1,
        lora_file_2,
        lora_multiplier_2,
        lora_file_3,
        lora_multiplier_3,
        lora_file_4,
        lora_multiplier_4,
        lora_file_5,
        lora_multiplier_5,
        fp8_optimization,
    ]
    # Components updated from the values `process` returns, in this order.
    outputs = [
        result_video,
        preview_image,
        progress_desc,
        progress_bar,
        start_button,
        end_button,
        progress_gcounter,
    ]

    # <<< --- START: PRESET COMPONENT LISTS --- >>>
    # Components whose values make up a preset. The order is significant: the
    # save handler receives these values positionally (after the preset name),
    # and the load handler must return values in the matching order.
    preset_components_inputs = [
        prompt, n_prompt, generation_count, seed,
        total_second_length, connection_second_length, padding_second_length, loop_num,
        steps, gs, gpu_memory_preservation, use_teacache, mp4_crf, progress_preview_option,
        lora_file_1, lora_multiplier_1, lora_file_2, lora_multiplier_2,
        lora_file_3, lora_multiplier_3, lora_file_4, lora_multiplier_4,
        lora_file_5, lora_multiplier_5,
        fp8_optimization,
    ]

    # Targets for load_preset_ui_action: the preset-name textbox followed by the
    # very same components. Deriving this list from the one above (instead of
    # repeating it by hand) keeps the two from drifting apart when a component
    # is added or removed.
    preset_components_outputs = [preset_name_textbox] + preset_components_inputs
    # <<< --- END: PRESET COMPONENT LISTS --- >>>


    # --- Event handlers ---
    # Start runs `process` through the queue; End calls `end_process` with the
    # queue disabled and with no inputs or outputs.
    start_button.click(
        fn=process,
        inputs=ips,
        outputs=outputs,
        queue=True,
    )
    end_button.click(
        fn=end_process,
        inputs=None,
        outputs=None,
        queue=False,
    )

    # <<< --- START: PRESET EVENT HANDLERS --- >>>
    # Load: the dropdown selection goes to load_preset_ui_action, whose return
    # values fill preset_components_outputs.
    load_preset_button.click(
        load_preset_ui_action,
        inputs=[preset_dropdown],
        outputs=preset_components_outputs,
    )
    # Save: the preset name is passed first, followed by every preset component.
    save_preset_button.click(
        save_preset_action,
        inputs=[preset_name_textbox, *preset_components_inputs],
        outputs=[preset_dropdown, preset_name_textbox],
    )
    # Delete: acts on the dropdown selection and updates dropdown and name box.
    delete_preset_button.click(
        delete_preset_action,
        inputs=[preset_dropdown],
        outputs=[preset_dropdown, preset_name_textbox],
    )
    # Refresh: takes no inputs and only updates the dropdown.
    refresh_presets_button.click(
        refresh_preset_dropdown,
        inputs=None,
        outputs=[preset_dropdown],
    )
    # Mirror the dropdown selection into the preset-name textbox.
    preset_dropdown.change(
        lambda selected: selected,
        inputs=[preset_dropdown],
        outputs=[preset_name_textbox],
    )
    # <<< --- END: PRESET EVENT HANDLERS --- >>>

    # On page load: send update_text(0, 9) to the padding-frames text and the
    # result of refresh_preset_dropdown() to the preset dropdown.
    block.load(
        lambda: (update_text(0, 9), refresh_preset_dropdown()),
        inputs=None,
        outputs=[padding_frames_text, preset_dropdown],
    )

# --- Server launch ---
# Host, port, share flag and browser auto-open are all taken from `args`.
launch_options = {
    "server_name": args.server,
    "server_port": args.port,
    "share": args.share,
    "inbrowser": args.inbrowser,
}
block.launch(**launch_options)