import json
from pathlib import Path
from typing import List

import click
from bs4 import BeautifulSoup

from chandra.input import load_file
from chandra.model import InferenceManager
from chandra.model.schema import BatchInputItem
from chandra.output import translate_image_alts_to_korean


def _parse_span(cell, attr: str) -> int:
    """Return the integer value of a cell's ``rowspan``/``colspan`` attribute.

    Falls back to 1 when the attribute is absent or cannot be parsed as an
    integer (e.g. ``rowspan=""`` or ``colspan="2;"``), so that a single
    malformed attribute does not abort the whole conversion.
    """
    try:
        return int(cell.get(attr, 1))
    except (TypeError, ValueError):
        return 1


def _extract_cell(cell) -> dict:
    """Describe one ``<td>``/``<th>`` element as a plain dict.

    Always contains ``text``, ``rowspan`` and ``colspan``. Checkbox keys
    (``has_checkbox``, ``checked``) and input keys (``has_input``,
    ``input_type``, ``input_value``) are only added when such an ``<input>``
    is present inside the cell.
    """
    info = {
        "text": cell.get_text(strip=True),
        "rowspan": _parse_span(cell, 'rowspan'),
        "colspan": _parse_span(cell, 'colspan')
    }

    # Checkbox inside the cell: presence of the `checked` attribute means ticked.
    checkbox = cell.find('input', {'type': 'checkbox'})
    if checkbox is not None:
        info["has_checkbox"] = True
        info["checked"] = checkbox.get('checked') is not None

    # First non-checkbox input. An <input> without a `type` attribute is
    # treated as a text input, the same default the form-field pass uses.
    input_field = next(
        (
            candidate
            for candidate in cell.find_all('input')
            if candidate.get('type', 'text') != 'checkbox'
        ),
        None,
    )
    if input_field is not None:
        info["has_input"] = True
        info["input_type"] = input_field.get('type', 'text')
        info["input_value"] = input_field.get('value', '')

    return info


def html_to_structured_json(html_content: str) -> dict:
    """Convert an HTML document into a structured, JSON-serializable dict.

    Args:
        html_content: HTML markup to parse (parsed with ``html.parser``).

    Returns:
        A dict with the keys:

        * ``title``: text of the first ``h1``-``h4`` element, or ``""``.
        * ``sections``: one entry per ``<div data-label=...>`` with its
          ``type`` (the label), ``bbox`` (raw ``data-bbox`` attribute string,
          ``"[]"`` when absent) and text ``content``.
        * ``tables``: per table, ``headers`` (rows of ``<th>`` cells found in
          ``<thead>``) and ``rows`` (body rows); each cell is a dict as
          produced by ``_extract_cell``.
        * ``forms``: every ``<input>`` in the document with its ``type``,
          ``value`` and ``checked`` (``None`` for non-checkbox inputs).
        * ``paragraphs``: non-empty ``<p>`` texts.
        * ``lists``: per ``<ol>``/``<ul>``, its ``type`` and direct ``<li>``
          item texts.
    """
    soup = BeautifulSoup(html_content, 'html.parser')

    result = {
        "title": "",
        "sections": [],
        "tables": [],
        "paragraphs": [],
        "lists": [],
        "forms": []
    }

    # Title: the first heading (h1-h4) in document order.
    title = soup.find(['h1', 'h2', 'h3', 'h4'])
    if title:
        result["title"] = title.get_text(strip=True)

    # Sections: every div carrying a data-label attribute.
    for div in soup.find_all('div', attrs={'data-label': True}):
        result["sections"].append({
            "type": div.get('data-label'),
            "bbox": div.get('data-bbox', '[]'),
            "content": div.get_text(strip=True)
        })

    # Tables.
    for table in soup.find_all('table'):
        table_data = {
            "headers": [],
            "rows": []
        }

        # Header rows come from <thead> only.
        thead = table.find('thead')
        if thead:
            for tr in thead.find_all('tr'):
                header_row = [
                    {
                        "text": th.get_text(strip=True),
                        "rowspan": _parse_span(th, 'rowspan'),
                        "colspan": _parse_span(th, 'colspan')
                    }
                    for th in tr.find_all('th')
                ]
                if header_row:
                    table_data["headers"].append(header_row)

        # Body rows: direct <tr> children of <tbody>, or of the table itself
        # when there is no <tbody>.
        tbody = table.find('tbody') or table
        for tr in tbody.find_all('tr', recursive=False):
            row = [_extract_cell(td) for td in tr.find_all(['td', 'th'])]
            if row:
                table_data["rows"].append(row)

        if table_data["headers"] or table_data["rows"]:
            result["tables"].append(table_data)

    # Form fields: every <input> in the document, inside tables or not.
    for input_elem in soup.find_all('input'):
        input_type = input_elem.get('type', 'text')
        result["forms"].append({
            "type": input_type,
            "value": input_elem.get('value', ''),
            "checked": input_elem.get('checked') is not None if input_type == 'checkbox' else None
        })

    # Paragraphs: skip those that are empty after stripping.
    for p in soup.find_all('p'):
        text = p.get_text(strip=True)
        if text:
            result["paragraphs"].append(text)

    # Lists: only direct <li> children, so nested lists are reported separately.
    for list_elem in soup.find_all(['ol', 'ul']):
        list_items = [
            li.get_text(strip=True)
            for li in list_elem.find_all('li', recursive=False)
        ]
        if list_items:
            result["lists"].append({
                "type": "ordered" if list_elem.name == 'ol' else "unordered",
                "items": list_items
            })

    return result


def get_supported_files(input_path: Path) -> List[Path]:
    """Return the supported image/PDF files found at ``input_path``.

    Args:
        input_path: Either a single file or a directory. Directories are
            scanned non-recursively.

    Returns:
        ``[input_path]`` for a supported file, or the sorted list of supported
        regular files directly inside the directory (possibly empty).

    Raises:
        click.BadParameter: If ``input_path`` is a file with an unsupported
            extension, or is neither a file nor a directory.
    """
    supported_extensions = {
        ".pdf",
        ".png",
        ".jpg",
        ".jpeg",
        ".gif",
        ".webp",
        ".tiff",
        ".bmp",
    }

    if input_path.is_file():
        if input_path.suffix.lower() in supported_extensions:
            return [input_path]
        raise click.BadParameter(f"Unsupported file type: {input_path.suffix}")

    if input_path.is_dir():
        # Compare the lower-cased suffix, exactly as the single-file branch
        # does. Globbing "*.pdf" plus "*.PDF" misses mixed-case suffixes such
        # as ".Jpg" and can list a file twice where glob matching is
        # case-insensitive. Directories whose names merely look like a
        # supported file are skipped.
        return sorted(
            entry
            for entry in input_path.iterdir()
            if entry.is_file() and entry.suffix.lower() in supported_extensions
        )

    raise click.BadParameter(f"Path does not exist: {input_path}")


def _translate_page_chunks(results: List) -> List[dict]:
    """Collect each page's chunks, translating image alt text in their content.

    Pages without chunks are omitted. Each chunk is shallow-copied so the
    objects held by ``results`` are left untouched.
    """
    pages = []
    for page_num, result in enumerate(results):
        chunks = getattr(result, "chunks", None)
        if not chunks:
            continue

        translated_chunks = []
        for chunk in chunks:
            translated_chunk = chunk.copy()
            content = translated_chunk.get("content")
            # Only string content containing an <img> tag needs translating;
            # the type check keeps a missing/None content from raising.
            if isinstance(content, str) and "<img" in content:
                translated_chunk["content"] = translate_image_alts_to_korean(content)
            translated_chunks.append(translated_chunk)

        pages.append({"page_num": page_num, "chunks": translated_chunks})
    return pages


def _first_section_header_text(all_chunks: List[dict]) -> str:
    """Return the text of the first non-empty ``Section-Header`` chunk, or ``""``."""
    for page_data in all_chunks:
        for chunk in page_data.get("chunks", []):
            if chunk.get("label") != "Section-Header":
                continue
            # Chunk content is HTML; strip the tags to get the heading text.
            chunk_soup = BeautifulSoup(chunk.get("content") or "", "html.parser")
            title_text = chunk_soup.get_text(strip=True)
            if title_text:
                return title_text
    return ""


def save_merged_output(
    output_dir: Path,
    file_name: str,
    results: List,
    save_images: bool = True,
    save_html: bool = True,
    paginate_output: bool = False,
):
    """Merge per-page OCR results and write them under ``output_dir/<stem>/``.

    ``<stem>`` is ``file_name`` without its extension. The following files are
    written into that subfolder:

    * ``<stem>.md``: merged Markdown with image alt text translated.
    * ``<stem>.html``: merged, translated HTML (only when ``save_html``).
    * ``<stem>.json``: structured data derived from the translated HTML plus
      the per-page chunks.
    * ``<stem>_metadata.json``: per-page and total token/chunk/image counts.
    * extracted images (only when ``save_images``), saved under the names
      used as keys of each result's ``images`` mapping.

    Args:
        output_dir: Root output directory; the per-file subfolder is created
            if needed.
        file_name: Name of the source document.
        results: Per-page result objects, in page order. Each is read for
            ``markdown``, ``html``, ``token_count``, ``chunks``, ``images``
            and ``page_box``.
        save_images: Whether to save extracted images.
        save_html: Whether to write the merged HTML file.
        paginate_output: Whether to insert a separator before every page
            after the first in the merged Markdown and HTML.
    """
    # Create subfolder for this file.
    safe_name = Path(file_name).stem
    file_output_dir = output_dir / safe_name
    file_output_dir.mkdir(parents=True, exist_ok=True)

    all_markdown = []
    all_html = []
    all_metadata = []
    total_tokens = 0
    total_chunks = 0
    total_images = 0

    for page_num, result in enumerate(results):
        # Separator before every page except the first: the Markdown one is
        # the page index followed by 48 dashes, the HTML one is a comment
        # carrying the 1-based page number.
        if page_num > 0 and paginate_output:
            all_markdown.append(f"\n\n{page_num}" + "-" * 48 + "\n\n")
            all_html.append(f"\n\n<!-- Page {page_num + 1} -->\n\n")

        all_markdown.append(result.markdown)
        all_html.append(result.html)

        num_chunks = len(result.chunks)
        num_images = len(result.images)
        total_tokens += result.token_count
        total_chunks += num_chunks
        total_images += num_images

        all_metadata.append({
            "page_num": page_num,
            "page_box": result.page_box,
            "token_count": result.token_count,
            "num_chunks": num_chunks,
            "num_images": num_images,
        })

        # Images go directly into the per-file folder, which already exists.
        if save_images and result.images:
            for img_name, pil_image in result.images.items():
                pil_image.save(file_output_dir / img_name)

    # Translate image alt attributes to Korean before anything is written.
    merged_html = "".join(all_html)
    click.echo("  번역 중: 이미지 설명을 한국어로 변환...")
    translated_html = translate_image_alts_to_korean(merged_html)

    # The same helper is applied to the Markdown, which may contain both
    # HTML <img> tags and Markdown ![alt](src) images.
    merged_markdown = "".join(all_markdown)
    translated_markdown = translate_image_alts_to_korean(merged_markdown)

    markdown_path = file_output_dir / f"{safe_name}.md"
    with open(markdown_path, "w", encoding="utf-8") as f:
        f.write(translated_markdown)

    if save_html:
        html_path = file_output_dir / f"{safe_name}.html"
        with open(html_path, "w", encoding="utf-8") as f:
            f.write(translated_html)

    # Structured JSON is always built from the translated HTML, whether or
    # not the HTML file itself was written.
    structured_data = html_to_structured_json(translated_html)
    all_chunks = _translate_page_chunks(results)
    structured_data["chunks"] = all_chunks

    # Fall back to the first section header chunk when the HTML had no heading.
    if not structured_data["title"] and all_chunks:
        structured_data["title"] = _first_section_header_text(all_chunks)

    structured_json_path = file_output_dir / f"{safe_name}.json"
    with open(structured_json_path, "w", encoding="utf-8") as f:
        json.dump(structured_data, f, ensure_ascii=False, indent=2)

    metadata = {
        "file_name": file_name,
        "num_pages": len(results),
        "total_token_count": total_tokens,
        "total_chunks": total_chunks,
        "total_images": total_images,
        "pages": all_metadata,
    }
    metadata_path = file_output_dir / f"{safe_name}_metadata.json"
    with open(metadata_path, "w", encoding="utf-8") as f:
        # ensure_ascii=False keeps non-ASCII file names readable, matching
        # how the structured JSON above is written.
        json.dump(metadata, f, ensure_ascii=False, indent=2)

    click.echo(f"  Saved: {markdown_path} ({len(results)} page(s))")
    click.echo(f"  Saved: {structured_json_path} (JSON with structured data and chunks)")


@click.command()
@click.argument("input_path", type=click.Path(exists=True, path_type=Path))
@click.argument("output_path", type=click.Path(path_type=Path))
@click.option(
    "--method",
    type=click.Choice(["hf", "vllm"], case_sensitive=False),
    default="vllm",
    help="Inference method: 'hf' for local model, 'vllm' for vLLM server.",
)
@click.option(
    "--page-range",
    type=str,
    default=None,
    help="Page range for PDFs (e.g., '1-5,7,9-12'). Only applicable to PDF files.",
)
@click.option(
    "--max-output-tokens",
    type=int,
    default=None,
    help="Maximum number of output tokens per page.",
)
@click.option(
    "--max-workers",
    type=int,
    default=None,
    help="Maximum number of parallel workers for vLLM inference.",
)
@click.option(
    "--max-retries",
    type=int,
    default=None,
    help="Maximum number of retries for vLLM inference.",
)
@click.option(
    "--include-images/--no-images",
    default=True,
    help="Include images in output.",
)
@click.option(
    "--include-headers-footers/--no-headers-footers",
    default=False,
    help="Include page headers and footers in output.",
)
@click.option(
    "--save-html/--no-html",
    default=True,
    help="Save HTML output files.",
)
@click.option(
    "--batch-size",
    # A batch size below 1 would make the page-batching range() either raise
    # (0) or yield no batches at all (negative), so reject it up front.
    type=click.IntRange(min=1),
    default=None,
    help="Number of pages to process in a batch.",
)
@click.option(
    # "--paginate_output" is kept for backward compatibility; the hyphenated
    # spelling matches every other option of this command.
    "--paginate-output",
    "--paginate_output",
    "paginate_output",
    is_flag=True,
    default=False,
    help="Insert page separators between pages in the merged output.",
)
def main(
    input_path: Path,
    output_path: Path,
    method: str,
    page_range: str,
    max_output_tokens: int,
    max_workers: int,
    max_retries: int,
    include_images: bool,
    include_headers_footers: bool,
    save_html: bool,
    batch_size: int,
    paginate_output: bool,
):
    """Run OCR on a file or directory and save merged per-document output.

    \f
    Everything after the form feed above is hidden from ``--help``.

    Each supported input file is loaded, run through the model in batches of
    ``batch_size`` pages with the ``ocr_layout`` prompt, and its pages are
    merged and written by ``save_merged_output``. A failure on one file is
    reported on stderr and processing continues with the next file.

    When ``batch_size`` is not given it defaults to 1 for ``hf`` and 28 for
    ``vllm``. ``max_workers`` and ``max_retries`` are only forwarded to the
    model when ``method`` is ``vllm``.
    """
    # Defensive: the default-batch-size branches below compare lower-case names.
    method = method.lower()

    if method == "hf":
        click.echo(
            "When using '--method hf', ensure that the batch size is set correctly.  We will default to batch size of 1."
        )
        if batch_size is None:
            batch_size = 1
    elif method == "vllm":
        if batch_size is None:
            batch_size = 28

    click.echo("Chandra CLI - Starting OCR processing")
    click.echo(f"Input: {input_path}")
    click.echo(f"Output: {output_path}")
    click.echo(f"Method: {method}")

    # Create output directory
    output_path.mkdir(parents=True, exist_ok=True)

    # Resolve the inputs before loading the model, so that an unsupported or
    # empty input fails fast instead of after the model has been loaded.
    files_to_process = get_supported_files(input_path)
    click.echo(f"\nFound {len(files_to_process)} file(s) to process.")

    if not files_to_process:
        click.echo("No supported files found. Exiting.")
        return

    # Load model
    click.echo(f"\nLoading model with method '{method}'...")
    model = InferenceManager(method=method)
    click.echo("Model loaded successfully.")

    # Generation options are identical for every batch, so build them once.
    generate_kwargs = {
        "include_images": include_images,
        "include_headers_footers": include_headers_footers,
    }
    if max_output_tokens is not None:
        generate_kwargs["max_output_tokens"] = max_output_tokens
    if method == "vllm":
        if max_workers is not None:
            generate_kwargs["max_workers"] = max_workers
        if max_retries is not None:
            generate_kwargs["max_retries"] = max_retries

    load_config = {"page_range": page_range} if page_range else {}

    # Process each file
    for file_idx, file_path in enumerate(files_to_process, 1):
        click.echo(
            f"\n[{file_idx}/{len(files_to_process)}] Processing: {file_path.name}"
        )

        try:
            # Load page images from the file
            images = load_file(str(file_path), load_config)
            click.echo(f"  Loaded {len(images)} page(s)")

            # Accumulate all results for this document
            all_results = []

            # Process pages in batches
            for batch_start in range(0, len(images), batch_size):
                batch_end = min(batch_start + batch_size, len(images))
                batch = [
                    BatchInputItem(image=img, prompt_type="ocr_layout")
                    for img in images[batch_start:batch_end]
                ]

                click.echo(f"  Processing pages {batch_start + 1}-{batch_end}...")
                all_results.extend(model.generate(batch, **generate_kwargs))

            # Save merged output for all pages
            save_merged_output(
                output_path,
                file_path.name,
                all_results,
                save_images=include_images,
                save_html=save_html,
                paginate_output=paginate_output,
            )

            click.echo(f"  Completed: {file_path.name}")

        except Exception as e:
            # Per-file boundary: report the failure and move on so that one
            # bad document does not abort the whole run.
            click.echo(f"  Error processing {file_path.name}: {e}", err=True)
            continue

    click.echo(f"\nProcessing complete. Results saved to: {output_path}")


# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
