#!/usr/bin/env python3
"""
Image Downloader Script

This script downloads images from a REST API that provides:
1. An endpoint to list all assets
2. An endpoint to download individual assets in full resolution

Usage:
    python image_downloader.py --api-url <base_url> --list-endpoint <endpoint> --download-endpoint <endpoint> --output-dir <directory>
"""

import argparse
import asyncio
import hashlib
import json
import logging
import os
import sys
import time
from pathlib import Path
from typing import Any, Dict, List, Optional
from urllib.parse import urlencode, urljoin, urlparse

import aiofiles
import aiohttp
from tqdm import tqdm

# Optional project-local helpers: fall back to None when not importable
# so the script still runs without auth / tracking support.
try:
    from src.auth_manager import AuthManager
except ImportError:
    AuthManager = None

try:
    from src.asset_tracker import AssetTracker
except ImportError:
    AssetTracker = None

|
class ImageDownloader:
    """Concurrently download image assets from a REST API.

    The API is expected to expose a listing endpoint that returns JSON and
    a per-asset full-resolution download endpoint.  Supports API-key or
    email/password authentication and optional download tracking (via the
    project-local ``AssetTracker``) to avoid re-fetching unchanged assets.
    """

    def __init__(
        self,
        api_url: str,
        list_endpoint: str,
        download_endpoint: str,
        output_dir: str,
        max_concurrent: int = 5,
        timeout: int = 30,
        api_key: Optional[str] = None,
        email: Optional[str] = None,
        password: Optional[str] = None,
        track_assets: bool = True,
    ):
        """
        Initialize the image downloader.

        Args:
            api_url: Base URL of the API (trailing slash is stripped).
            list_endpoint: Endpoint to get the list of assets.
            download_endpoint: Endpoint to download individual assets.
                NOTE(review): currently unused by get_download_url, which
                hard-codes the /v1/media/<id>/full scheme — confirm intent.
            output_dir: Directory to save downloaded images (created if missing).
            max_concurrent: Maximum number of concurrent downloads.
            timeout: Request timeout in seconds.
            api_key: API key for authentication (sent as the x-api-key header).
            email: Email for login authentication.
            password: Password for login authentication.
            track_assets: Whether to enable asset tracking to avoid re-downloads.
        """
        self.api_url = api_url.rstrip("/")
        self.list_endpoint = list_endpoint.lstrip("/")
        self.download_endpoint = download_endpoint.lstrip("/")
        self.output_dir = Path(output_dir)
        self.max_concurrent = max_concurrent
        self.timeout = timeout
        self.api_key = api_key
        self.email = email
        self.password = password
        self.auth_manager = None

        # Create output directory if it doesn't exist.
        self.output_dir.mkdir(parents=True, exist_ok=True)

        # Setup logging to both a file in the output dir and the console.
        # NOTE: basicConfig configures the root logger, so this affects
        # the whole process — preserved from the original behavior.
        logging.basicConfig(
            level=logging.INFO,
            format="%(asctime)s - %(levelname)s - %(message)s",
            handlers=[
                logging.FileHandler(self.output_dir / "download.log"),
                logging.StreamHandler(),
            ],
        )
        self.logger = logging.getLogger(__name__)

        # Initialize asset tracker if enabled and the module is available.
        self.asset_tracker = None
        if track_assets and AssetTracker:
            self.asset_tracker = AssetTracker(storage_dir=str(self.output_dir))
            self.logger.info("Asset tracking enabled")
        elif track_assets:
            self.logger.warning(
                "Asset tracking requested but AssetTracker not available"
            )
        else:
            self.logger.info("Asset tracking disabled")

        # Download statistics, updated as downloads complete.
        self.stats = {"total": 0, "successful": 0, "failed": 0, "skipped": 0}

    async def authenticate(self):
        """Perform login authentication if credentials are provided.

        Raises:
            Exception: If login fails, or if only one of email/password
                was supplied.
        """
        if self.email and self.password and AuthManager:
            self.logger.info("Attempting login authentication...")
            self.auth_manager = AuthManager(self.api_url)
            success = await self.auth_manager.login(self.email, self.password)

            if success:
                self.logger.info("Login authentication successful")
            else:
                self.logger.error("Login authentication failed")
                raise Exception("Login authentication failed")
        elif self.email or self.password:
            # Half-supplied credentials are a configuration error.
            self.logger.warning(
                "Both email and password must be provided for login authentication"
            )
            raise Exception(
                "Both email and password must be provided for login authentication"
            )

    async def get_asset_list(
        self, session: "aiohttp.ClientSession"
    ) -> List[Dict[str, Any]]:
        """
        Fetch the list of assets from the API.

        Args:
            session: aiohttp session for making requests

        Returns:
            List of asset dictionaries

        Raises:
            aiohttp.ClientError: On HTTP failure (also logged).
            ValueError: If the response is neither a list nor a dict.
        """
        # Plain string join instead of urljoin: urljoin would discard any
        # path component of api_url (e.g. "https://host/api" + "assets"
        # -> "https://host/assets").
        url = f"{self.api_url}/{self.list_endpoint}"
        self.logger.info(f"Fetching asset list from: {url}")

        try:
            headers = {}

            # API key takes precedence over login-session headers.
            if self.api_key:
                headers["x-api-key"] = self.api_key
            elif self.auth_manager and self.auth_manager.is_authenticated():
                headers.update(self.auth_manager.get_auth_headers())

            # aiohttp expects a ClientTimeout object; a bare int here is
            # deprecated.
            async with session.get(
                url,
                headers=headers,
                timeout=aiohttp.ClientTimeout(total=self.timeout),
            ) as response:
                response.raise_for_status()
                data = await response.json()

                # Handle different response formats.
                if isinstance(data, list):
                    assets = data
                elif isinstance(data, dict):
                    # Common envelope keys for API responses.
                    for envelope in ("data", "results", "items"):
                        if envelope in data:
                            assets = data[envelope]
                            break
                    else:
                        assets = [data]  # Single asset
                else:
                    raise ValueError(f"Unexpected response format: {type(data)}")

                self.logger.info(f"Found {len(assets)} assets")
                return assets

        except Exception as e:
            self.logger.error(f"Failed to fetch asset list: {e}")
            raise

    def get_download_url(self, asset: Dict[str, Any]) -> str:
        """
        Generate the full-resolution download URL for an asset.

        Args:
            asset: Asset dictionary from the API

        Returns:
            Download URL for the asset
        """
        # Try common field names for asset identifiers.
        asset_id = None
        for field in ("id", "asset_id", "image_id", "file_id", "uuid", "key"):
            if field in asset:
                asset_id = asset[field]
                break

        if asset_id is None:
            # If no ID field found, fall back to the stringified asset.
            asset_id = str(asset)

        # Build the query string.  The API key is omitted entirely when
        # not configured (previously a literal "key=None" was emitted).
        params = {}
        if self.api_key:
            params["key"] = self.api_key
        params["u"] = asset.get("updated", "")

        # NOTE(review): the path hard-codes the /v1/media/<id>/full scheme
        # and ignores self.download_endpoint — confirm this is intended.
        # A plain string join keeps any path component of api_url, which
        # urljoin with a leading "/" would discard.
        return f"{self.api_url}/v1/media/{asset_id}/full?{urlencode(params)}"

    def get_filename(self, asset: Dict[str, Any], url: str) -> str:
        """
        Generate a unique, sanitized filename for the downloaded asset.

        Args:
            asset: Asset dictionary from the API
            url: Download URL (used as a fallback filename source)

        Returns:
            Filename for the asset, unique within the output directory
        """
        # Prefer a filename from asset metadata, in order of likelihood.
        for field in ("fileName", "filename", "name", "title"):
            if field in asset:
                filename = asset[field]
                break
        else:
            # Extract the filename from the URL path.
            filename = os.path.basename(urlparse(url).path)

        # If no extension, derive one from the MIME type (default .jpg).
        if "." not in filename:
            if "mimeType" in asset:
                ext = self._get_extension_from_mime(asset["mimeType"])
            elif "content_type" in asset:
                ext = self._get_extension_from_mime(asset["content_type"])
            else:
                ext = ".jpg"
            filename += ext

        filename = self._sanitize_filename(filename)

        # Ensure the filename is unique within the output directory by
        # appending a numeric suffix to the original stem.
        original_filename = filename
        counter = 1
        while (self.output_dir / filename).exists():
            stem, ext = os.path.splitext(original_filename)
            filename = f"{stem}_{counter}{ext}"
            counter += 1

        return filename

    def _get_extension_from_mime(self, mime_type: str) -> str:
        """Return the file extension for a MIME type (default: .jpg)."""
        mime_to_ext = {
            "image/jpeg": ".jpg",
            "image/jpg": ".jpg",
            "image/png": ".png",
            "image/gif": ".gif",
            "image/webp": ".webp",
            "image/bmp": ".bmp",
            "image/tiff": ".tiff",
            "image/svg+xml": ".svg",
        }
        return mime_to_ext.get(mime_type.lower(), ".jpg")

    def _sanitize_filename(self, filename: str) -> str:
        """Sanitize a filename by replacing invalid characters with '_'."""
        # Replace characters invalid on common filesystems.
        for char in '<>:"/\\|?*':
            filename = filename.replace(char, "_")

        # Remove leading/trailing spaces and dots.
        filename = filename.strip(". ")

        # Ensure the filename is not empty.
        return filename or "image"

    async def download_asset(
        self,
        session: "aiohttp.ClientSession",
        asset: Dict[str, Any],
        semaphore: asyncio.Semaphore,
    ) -> bool:
        """
        Download a single asset.

        Args:
            session: aiohttp session for making requests
            asset: Asset dictionary from the API
            semaphore: Semaphore to limit concurrent downloads

        Returns:
            True if download was successful (or skipped), False otherwise
        """
        async with semaphore:
            try:
                download_url = self.get_download_url(asset)
                filename = self.get_filename(asset, download_url)
                filepath = self.output_dir / filename

                # Skip files that already exist when no tracker is making
                # the new/modified decision for us.
                if filepath.exists() and not self.asset_tracker:
                    self.logger.info(f"Skipping {filename} (already exists)")
                    self.stats["skipped"] += 1
                    return True

                self.logger.info(f"Downloading {filename} from {download_url}")

                # aiohttp expects a ClientTimeout object; a bare int is
                # deprecated.
                async with session.get(
                    download_url,
                    timeout=aiohttp.ClientTimeout(total=self.timeout),
                ) as response:
                    response.raise_for_status()

                    # Warn (but proceed) if the payload is not an image.
                    content_type = response.headers.get("content-type", "")
                    if not content_type.startswith("image/"):
                        self.logger.warning(
                            f"Content type is not an image: {content_type}"
                        )

                    # Stream the body to disk in chunks.
                    async with aiofiles.open(filepath, "wb") as f:
                        async for chunk in response.content.iter_chunked(8192):
                            await f.write(chunk)

                # Mirror the server-side "updated" timestamp onto the file.
                if "updated" in asset:
                    try:
                        from datetime import datetime

                        # Parse the ISO timestamp ("Z" suffix normalized).
                        updated_time = datetime.fromisoformat(
                            asset["updated"].replace("Z", "+00:00")
                        )
                        os.utime(
                            filepath,
                            (updated_time.timestamp(), updated_time.timestamp()),
                        )
                        self.logger.info(
                            f"Set file modification time to {asset['updated']}"
                        )
                    except Exception as e:
                        self.logger.warning(
                            f"Failed to set file modification time: {e}"
                        )

                # Mark asset as downloaded in tracker.
                if self.asset_tracker:
                    self.asset_tracker.mark_asset_downloaded(asset, filepath, True)

                self.logger.info(f"Successfully downloaded {filename}")
                self.stats["successful"] += 1
                return True

            except Exception as e:
                # Record the failure.  Guard the tracker update: if the
                # original failure came from get_download_url/get_filename,
                # re-calling them here would raise again, escape this
                # handler, and leave the failure uncounted.
                if self.asset_tracker:
                    try:
                        download_url = self.get_download_url(asset)
                        filename = self.get_filename(asset, download_url)
                        self.asset_tracker.mark_asset_downloaded(
                            asset, self.output_dir / filename, False
                        )
                    except Exception:
                        self.logger.warning(
                            "Could not record failed asset in tracker"
                        )

                self.logger.error(
                    f"Failed to download asset {asset.get('id', 'unknown')}: {e}"
                )
                self.stats["failed"] += 1
                return False

    async def download_all_assets(self, force_redownload: bool = False):
        """
        Download all assets from the API.

        Args:
            force_redownload: If True, download all assets regardless of tracking
        """
        start_time = time.time()

        # Session with connection pooling; individual requests also set
        # their own per-request timeout.
        connector = aiohttp.TCPConnector(limit=100, limit_per_host=30)
        timeout = aiohttp.ClientTimeout(total=self.timeout)

        async with aiohttp.ClientSession(
            connector=connector, timeout=timeout
        ) as session:
            try:
                # Perform authentication if needed.
                await self.authenticate()

                all_assets = await self.get_asset_list(session)
                self.logger.info(
                    f"Retrieved {len(all_assets)} total assets from API"
                )

                if not all_assets:
                    self.logger.warning("No assets found to download")
                    return

                # Restrict to new/modified assets when tracking is active.
                if self.asset_tracker and not force_redownload:
                    assets = self.asset_tracker.get_new_assets(all_assets)
                    self.logger.info(
                        f"Found {len(assets)} new/modified assets to download"
                    )
                    if len(assets) == 0:
                        self.logger.info("All assets are up to date!")
                        return
                else:
                    assets = all_assets
                    if force_redownload:
                        self.logger.info(
                            "Force redownload enabled - downloading all assets"
                        )

                self.stats["total"] = len(assets)

                # Limit concurrency across all download tasks.
                semaphore = asyncio.Semaphore(self.max_concurrent)
                tasks = [
                    self.download_asset(session, asset, semaphore)
                    for asset in assets
                ]

                # Drive all downloads, updating the progress bar as each
                # one finishes (download_asset records its own outcome).
                with tqdm(total=len(tasks), desc="Downloading assets") as pbar:
                    for coro in asyncio.as_completed(tasks):
                        await coro
                        pbar.update(1)
                        pbar.set_postfix(
                            {
                                "Success": self.stats["successful"],
                                "Failed": self.stats["failed"],
                                "Skipped": self.stats["skipped"],
                            }
                        )

            except Exception as e:
                self.logger.error(f"Error during download process: {e}")
                raise

        # Print final statistics.
        elapsed_time = time.time() - start_time
        self.logger.info(f"Download completed in {elapsed_time:.2f} seconds")
        self.logger.info(f"Statistics: {self.stats}")


def main():
    """Parse command-line arguments and run the image downloader.

    Returns:
        Process exit code: 0 on success (including the --show-stats /
        --cleanup commands), 1 on error.
    """
    parser = argparse.ArgumentParser(
        description="Download images from a REST API",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Basic usage
  python image_downloader.py --api-url "https://api.example.com" \\
      --list-endpoint "/assets" \\
      --download-endpoint "/download" \\
      --output-dir "./images"

  # With custom concurrent downloads and timeout
  python image_downloader.py --api-url "https://api.example.com" \\
      --list-endpoint "/assets" \\
      --download-endpoint "/download" \\
      --output-dir "./images" \\
      --max-concurrent 10 \\
      --timeout 60
""",
    )

    parser.add_argument(
        "--api-url",
        required=True,
        help="Base URL of the API (e.g., https://api.example.com)",
    )

    parser.add_argument(
        "--list-endpoint",
        required=True,
        help="Endpoint to get the list of assets (e.g., /assets or /images)",
    )

    parser.add_argument(
        "--download-endpoint",
        required=True,
        help="Endpoint to download individual assets (e.g., /download or /assets)",
    )

    parser.add_argument(
        "--output-dir", required=True, help="Directory to save downloaded images"
    )

    parser.add_argument(
        "--max-concurrent",
        type=int,
        default=5,
        help="Maximum number of concurrent downloads (default: 5)",
    )

    parser.add_argument(
        "--timeout",
        type=int,
        default=30,
        help="Request timeout in seconds (default: 30)",
    )

    parser.add_argument(
        "--api-key", help="API key for authentication (x-api-key header)"
    )

    parser.add_argument("--email", help="Email for login authentication")

    parser.add_argument("--password", help="Password for login authentication")

    parser.add_argument(
        "--no-tracking",
        action="store_true",
        help="Disable asset tracking (will re-download all assets)",
    )

    parser.add_argument(
        "--force-redownload",
        action="store_true",
        help="Force re-download of all assets, even if already tracked",
    )

    parser.add_argument(
        "--show-stats",
        action="store_true",
        help="Show asset tracking statistics and exit",
    )

    parser.add_argument(
        "--cleanup",
        action="store_true",
        help="Clean up metadata for missing files and exit",
    )

    args = parser.parse_args()

    # Handle special commands first -- no download is performed.
    if args.show_stats or args.cleanup:
        if AssetTracker:
            tracker = AssetTracker(storage_dir=args.output_dir)
            if args.cleanup:
                tracker.cleanup_missing_files()
            if args.show_stats:
                tracker.print_stats()
        else:
            print("Asset tracking is not available")
        # Explicit exit code (previously an implicit `return None`,
        # inconsistent with the 0/1 contract of this function).
        return 0

    # Create the image downloader from the parsed options.
    downloader = ImageDownloader(
        api_url=args.api_url,
        list_endpoint=args.list_endpoint,
        download_endpoint=args.download_endpoint,
        output_dir=args.output_dir,
        max_concurrent=args.max_concurrent,
        timeout=args.timeout,
        api_key=args.api_key,
        email=args.email,
        password=args.password,
        track_assets=not args.no_tracking,
    )

    try:
        asyncio.run(
            downloader.download_all_assets(force_redownload=args.force_redownload)
        )
    except KeyboardInterrupt:
        print("\nDownload interrupted by user")
    except Exception as e:
        print(f"Error: {e}")
        return 1

    return 0


if __name__ == "__main__":
    # sys.exit is the canonical form for scripts; the `exit` builtin is
    # injected by the site module and may be absent (e.g. under python -S).
    sys.exit(main())