#!/usr/bin/env python3
"""
VRAM 사용량 확인 스크립트
모델 로드 전후의 GPU 메모리 사용량을 측정합니다.
"""

import torch
import gc

def print_gpu_memory():
    """GPU 메모리 사용량 출력"""
    if torch.cuda.is_available():
        for i in range(torch.cuda.device_count()):
            total = torch.cuda.get_device_properties(i).total_memory / 1024**3
            reserved = torch.cuda.memory_reserved(i) / 1024**3
            allocated = torch.cuda.memory_allocated(i) / 1024**3
            free = reserved - allocated

            print(f"GPU {i}: {torch.cuda.get_device_properties(i).name}")
            print(f"  Total: {total:.2f} GB")
            print(f"  Reserved: {reserved:.2f} GB")
            print(f"  Allocated: {allocated:.2f} GB")
            print(f"  Free (in reserved): {free:.2f} GB")
    else:
        print("CUDA is not available")

def main():
    print("=" * 60)
    print("모델 로드 전 GPU 메모리 상태")
    print("=" * 60)
    print_gpu_memory()

    print("\n" + "=" * 60)
    print("Chandra 모델 로드 중...")
    print("=" * 60)

    from chandra.model.hf import load_model

    model = load_model()

    print("\n" + "=" * 60)
    print("모델 로드 후 GPU 메모리 상태")
    print("=" * 60)
    print_gpu_memory()

    # 모델 파라미터 수 계산
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

    print("\n" + "=" * 60)
    print("모델 정보")
    print("=" * 60)
    print(f"총 파라미터 수: {total_params:,} ({total_params / 1e9:.2f}B)")
    print(f"학습 가능 파라미터: {trainable_params:,}")
    print(f"데이터 타입: {next(model.parameters()).dtype}")

    # 메모리 정리
    del model
    gc.collect()
    torch.cuda.empty_cache()

    print("\n" + "=" * 60)
    print("모델 언로드 후 GPU 메모리 상태")
    print("=" * 60)
    print_gpu_memory()

if __name__ == "__main__":
    main()
