#!/usr/bin/env python3
"""
오프라인 설치를 위한 RAG 임베딩 모델 다운로드 스크립트

이 스크립트는 온라인 환경에서 실행하여 필요한 모델을 다운로드합니다.
다운로드된 모델은 ~/.airun/models에 저장되며, 오프라인 환경으로 복사할 수 있습니다.

사용법:
    # 모든 모델 다운로드
    python3 scripts/download_offline_models.py

    # 특정 모델만 다운로드
    python3 scripts/download_offline_models.py --model nlpai-lab/KURE-v1

    # 다운로드 경로 지정
    python3 scripts/download_offline_models.py --output /path/to/models
"""

import os
import sys
import argparse
from pathlib import Path

# 필요한 모델 목록
REQUIRED_MODELS = {
    'embedding': 'nlpai-lab/KURE-v1',
    'semantic': 'snunlp/KR-SBERT-V40K-klueNLI-augSTS',
    'image': 'Bingsu/clip-vit-base-patch32-ko'
}

def download_model(model_name, output_dir, model_type='embedding'):
    """HuggingFace에서 모델 다운로드"""
    print(f"\n{'='*80}")
    print(f"모델 다운로드 중: {model_name} ({model_type})")
    print(f"저장 경로: {output_dir}")
    print(f"{'='*80}\n")

    try:
        if model_type in ['embedding', 'semantic']:
            # SentenceTransformer 모델 다운로드
            from sentence_transformers import SentenceTransformer

            print(f"[1/2] 모델 다운로드 중...")
            model = SentenceTransformer(model_name, cache_folder=output_dir)

            print(f"[2/2] 모델 검증 중...")
            # 간단한 테스트로 모델 검증
            test_embedding = model.encode(["테스트 문장"])
            print(f"✅ 임베딩 차원: {len(test_embedding[0])}")

        elif model_type == 'image':
            # CLIP 이미지 모델 다운로드
            from transformers import CLIPProcessor, CLIPModel

            print(f"[1/3] CLIP 모델 다운로드 중...")
            model = CLIPModel.from_pretrained(model_name, cache_dir=output_dir)

            print(f"[2/3] CLIP 프로세서 다운로드 중...")
            processor = CLIPProcessor.from_pretrained(model_name, cache_dir=output_dir)

            print(f"[3/3] 모델 검증 중...")
            # 간단한 검증
            print(f"✅ 모델 로드 성공")

        print(f"\n✅ 모델 다운로드 완료: {model_name}")
        return True

    except Exception as e:
        print(f"\n❌ 모델 다운로드 실패: {model_name}")
        print(f"에러: {str(e)}")
        return False

def check_dependencies():
    """필요한 Python 패키지 확인"""
    required_packages = [
        'sentence-transformers',
        'transformers',
        'torch'
    ]

    missing_packages = []

    print("의존성 패키지 확인 중...\n")
    for package in required_packages:
        try:
            __import__(package.replace('-', '_'))
            print(f"✅ {package}")
        except ImportError:
            print(f"❌ {package} (설치 필요)")
            missing_packages.append(package)

    if missing_packages:
        print(f"\n경고: 다음 패키지가 필요합니다:")
        print(f"pip install {' '.join(missing_packages)}\n")
        return False

    return True

def main():
    parser = argparse.ArgumentParser(
        description='오프라인 설치를 위한 RAG 임베딩 모델 다운로드',
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
예제:
    # 모든 모델 다운로드
    python3 %(prog)s

    # 특정 모델만 다운로드
    python3 %(prog)s --model nlpai-lab/KURE-v1

    # 다운로드 경로 지정
    python3 %(prog)s --output /path/to/models
        """
    )

    parser.add_argument(
        '--model',
        choices=['embedding', 'semantic', 'image', 'all'],
        default='all',
        help='다운로드할 모델 타입 (기본값: all)'
    )

    parser.add_argument(
        '--output',
        type=str,
        default=os.path.join(os.path.expanduser('~'), '.airun', 'models'),
        help='모델 저장 경로 (기본값: ~/.airun/models)'
    )

    parser.add_argument(
        '--skip-check',
        action='store_true',
        help='의존성 확인 건너뛰기'
    )

    args = parser.parse_args()

    print("=" * 80)
    print("AIRUN - 오프라인 모델 다운로드 도구")
    print("=" * 80)
    print()

    # 의존성 확인
    if not args.skip_check:
        if not check_dependencies():
            print("\n에러: 필요한 패키지를 먼저 설치해주세요.")
            return 1

    # 출력 디렉토리 생성
    output_dir = Path(args.output)
    output_dir.mkdir(parents=True, exist_ok=True)

    print(f"\n저장 경로: {output_dir}\n")

    # 다운로드할 모델 결정
    models_to_download = {}
    if args.model == 'all':
        models_to_download = REQUIRED_MODELS
    else:
        models_to_download = {args.model: REQUIRED_MODELS[args.model]}

    # 모델 다운로드
    success_count = 0
    total_count = len(models_to_download)

    for model_type, model_name in models_to_download.items():
        if download_model(model_name, str(output_dir), model_type):
            success_count += 1

    # 결과 요약
    print("\n" + "=" * 80)
    print(f"다운로드 완료: {success_count}/{total_count} 성공")
    print("=" * 80)

    if success_count == total_count:
        print(f"\n✅ 모든 모델이 성공적으로 다운로드되었습니다.")
        print(f"\n다음 단계:")
        print(f"1. 오프라인 환경으로 모델 복사:")
        print(f"   rsync -avz {output_dir}/ target-machine:~/.airun/models/")
        print(f"\n2. 오프라인 환경에서 환경 변수 설정:")
        print(f"   export RAG_OFFLINE_MODE=true")
        print(f"   export HF_HUB_OFFLINE=1")
        print(f"\n3. RAG 서비스 재시작:")
        print(f"   echo 'exitem08' | sudo -S systemctl restart airun-rag")
        return 0
    else:
        print(f"\n⚠️  일부 모델 다운로드에 실패했습니다.")
        return 1

if __name__ == '__main__':
    sys.exit(main())
