Files

180 lines
5.3 KiB
Python

"""General-purpose utility functions for AYN Antivirus."""
from __future__ import annotations

import hashlib
import os
import platform
import re
import socket
import time
import uuid
from datetime import timedelta
from pathlib import Path
from typing import Any, Dict

import psutil

from ayn_antivirus.constants import SCAN_CHUNK_SIZE
# ---------------------------------------------------------------------------
# Human-readable formatting
# ---------------------------------------------------------------------------
def format_size(size_bytes: int | float) -> str:
    """Render a byte count as a short human-readable string.

    The value is divided by 1024 until its magnitude drops below 1024 or
    the largest unit (PB) is reached, then formatted with one decimal
    place, e.g. ``"1.5 KB"``.

    Parameters
    ----------
    size_bytes:
        Size in bytes; negative values keep their sign.

    Returns
    -------
    str
        ``"<value> <unit>"`` where unit is one of B, KB, MB, GB, TB, PB.
    """
    units = ("B", "KB", "MB", "GB", "TB", "PB")
    largest = len(units) - 1
    value = size_bytes
    index = 0
    # ``not abs(value) < 1024`` (rather than ``>=``) keeps NaN walking all
    # the way up to the largest unit.
    while index < largest and not abs(value) < 1024:
        value /= 1024
        index += 1
    return f"{value:.1f} {units[index]}"
def format_duration(seconds: float) -> str:
    """Render a number of seconds as ``"<H>h <M>m <S>s"``.

    Fractional seconds are truncated.  The hour and minute components are
    omitted when they are zero; the seconds component is always present.

    Parameters
    ----------
    seconds:
        Duration in seconds.  Negative values yield ``"0s"``.

    Returns
    -------
    str
        For example ``"1h 23m 45s"``, ``"2m 0s"`` or ``"7s"``.
    """
    if seconds < 0:
        return "0s"

    whole_seconds = int(timedelta(seconds=int(seconds)).total_seconds())
    total_minutes, secs = divmod(whole_seconds, 60)
    hours, minutes = divmod(total_minutes, 60)

    # Only non-zero hour/minute components are shown.
    pieces = [
        f"{amount}{suffix}"
        for amount, suffix in ((hours, "h"), (minutes, "m"))
        if amount
    ]
    pieces.append(f"{secs}s")
    return " ".join(pieces)
# ---------------------------------------------------------------------------
# Privilege check
# ---------------------------------------------------------------------------
def is_root() -> bool:
    """Report whether the process's effective user ID is 0 (root)."""
    effective_uid = os.geteuid()
    return effective_uid == 0
# ---------------------------------------------------------------------------
# System information
# ---------------------------------------------------------------------------
def get_system_info() -> Dict[str, Any]:
    """Collect hostname, OS, kernel, uptime, CPU, and memory details.

    Returns
    -------
    Dict[str, Any]
        A mapping with the following keys:

        * ``hostname`` -- result of :func:`socket.gethostname`.
        * ``os`` -- ``"<system> <release>"`` from :mod:`platform`.
        * ``os_pretty`` -- :func:`platform.platform` string.
        * ``kernel`` -- :func:`platform.release`.
        * ``architecture`` -- :func:`platform.machine`.
        * ``cpu_count`` / ``cpu_physical`` -- logical and physical CPU
          counts as reported by psutil.
        * ``cpu_percent`` -- CPU utilisation from
          ``psutil.cpu_percent(interval=0.1)``.
        * ``memory_total`` / ``memory_available`` -- raw values from
          ``psutil.virtual_memory()``, with ``*_human`` variants rendered
          by :func:`format_size`.
        * ``memory_percent`` -- ``percent`` field of the same snapshot.
        * ``uptime_seconds`` -- current time minus ``psutil.boot_time()``,
          with ``uptime_human`` rendered by :func:`format_duration`.
    """
    mem = psutil.virtual_memory()
    boot = psutil.boot_time()
    # Use the stdlib clock directly; ``psutil.time`` only exists because
    # psutil happens to import ``time`` internally and is not public API.
    uptime_secs = time.time() - boot
    return {
        "hostname": socket.gethostname(),
        "os": f"{platform.system()} {platform.release()}",
        "os_pretty": platform.platform(),
        "kernel": platform.release(),
        "architecture": platform.machine(),
        "cpu_count": psutil.cpu_count(logical=True),
        "cpu_physical": psutil.cpu_count(logical=False),
        "cpu_percent": psutil.cpu_percent(interval=0.1),
        "memory_total": mem.total,
        "memory_total_human": format_size(mem.total),
        "memory_available": mem.available,
        "memory_available_human": format_size(mem.available),
        "memory_percent": mem.percent,
        "uptime_seconds": uptime_secs,
        "uptime_human": format_duration(uptime_secs),
    }
# ---------------------------------------------------------------------------
# Path safety
# ---------------------------------------------------------------------------
def safe_path(path: str | Path) -> Path:
    """Normalise *path* into an absolute, resolved :class:`~pathlib.Path`.

    Surrounding whitespace is stripped, a leading ``~`` is expanded, and
    the result is passed through :meth:`pathlib.Path.resolve`, which makes
    it absolute and collapses symlinks and ``..`` components.

    Parameters
    ----------
    path:
        The path to normalise, as a string or ``Path``.

    Returns
    -------
    Path
        The resolved path.

    Raises
    ------
    ValueError
        If the path is empty (after stripping) or contains null bytes.
    """
    text = str(path).strip()
    if text == "":
        raise ValueError("Path must not be empty")
    if text.find("\x00") != -1:
        raise ValueError("Path must not contain null bytes")

    expanded = os.path.expanduser(text)
    return Path(expanded).resolve()
# ---------------------------------------------------------------------------
# ID generation
# ---------------------------------------------------------------------------
def generate_id() -> str:
    """Create a random identifier: a UUID4 as 32 lowercase hex characters."""
    fresh = uuid.uuid4()
    return fresh.hex
# ---------------------------------------------------------------------------
# File hashing
# ---------------------------------------------------------------------------
def hash_file(path: str | Path, algo: str = "sha256") -> str:
    """Compute the hex digest of the file at *path*.

    The file is opened in binary mode and fed to the hash in pieces of
    ``SCAN_CHUNK_SIZE`` bytes, so the whole file is never held in memory.

    Parameters
    ----------
    path:
        Location of the file to hash.
    algo:
        Any algorithm name accepted by :func:`hashlib.new`.

    Returns
    -------
    str
        The hexadecimal digest.

    Raises
    ------
    ValueError
        If :func:`hashlib.new` does not recognise *algo*.
    OSError
        If the file cannot be opened or read.
    """
    # Build the hasher first so an unknown algorithm fails before any I/O.
    digest = hashlib.new(algo)
    with open(path, "rb") as stream:
        # The sentinel form of iter() stops at the empty read that marks EOF.
        for block in iter(lambda: stream.read(SCAN_CHUNK_SIZE), b""):
            digest.update(block)
    return digest.hexdigest()
# ---------------------------------------------------------------------------
# Validation
# ---------------------------------------------------------------------------
# Patterns are compiled once at import time and shared by the validators.

# Dotted-quad IPv4 with each octet in 0-255.  ``re.ASCII`` restricts ``\d``
# to 0-9; without it ``\d`` also matches other Unicode decimal digits
# (e.g. Arabic-Indic), which are not valid in an IP address.
# Note: octets with leading zeros such as "01" are accepted by this pattern.
_IPV4_RE = re.compile(
    r"^(?:(?:25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}"
    r"(?:25[0-5]|2[0-4]\d|[01]?\d\d?)$",
    re.ASCII,
)

# One or more labels (1-63 chars, alphanumeric at both ends, hyphens allowed
# inside) followed by an alphabetic top-level label of at least two letters.
_DOMAIN_RE = re.compile(
    r"^(?:[a-zA-Z0-9](?:[a-zA-Z0-9\-]{0,61}[a-zA-Z0-9])?\.)+"
    r"[a-zA-Z]{2,}$"
)


def validate_ip(ip: str) -> bool:
    """Return ``True`` if *ip* is a valid IPv4 address.

    Leading and trailing whitespace is ignored.  Only ASCII digits are
    accepted.

    Parameters
    ----------
    ip:
        Candidate address in dotted-quad notation.
    """
    return bool(_IPV4_RE.match(ip.strip()))


def validate_domain(domain: str) -> bool:
    """Return ``True`` if *domain* looks like a valid DNS domain name.

    Leading and trailing whitespace is ignored, and a single trailing dot
    (the fully-qualified root marker) is allowed.  Names longer than 253
    characters are rejected.

    Parameters
    ----------
    domain:
        Candidate domain name, e.g. ``"example.com"``.
    """
    d = domain.strip()
    # Drop at most one trailing dot: "example.com." is a valid FQDN, but
    # "example.com.." contains an empty label and must not validate.
    if d.endswith("."):
        d = d[:-1]
    if len(d) > 253:
        return False
    return bool(_DOMAIN_RE.match(d))