"""Configuration loader for AYN Antivirus.""" from __future__ import annotations import os from dataclasses import dataclass, field from pathlib import Path from typing import Any, Dict, List, Optional import yaml from ayn_antivirus.constants import ( DEFAULT_CONFIG_PATHS, DEFAULT_DASHBOARD_DB_PATH, DEFAULT_DASHBOARD_HOST, DEFAULT_DASHBOARD_PASSWORD, DEFAULT_DASHBOARD_PORT, DEFAULT_DASHBOARD_USERNAME, DEFAULT_DB_PATH, DEFAULT_LOG_PATH, DEFAULT_QUARANTINE_PATH, DEFAULT_SCAN_PATH, MAX_FILE_SIZE, ) @dataclass class Config: """Application configuration, loaded from YAML config files or environment variables.""" scan_paths: List[str] = field(default_factory=lambda: [DEFAULT_SCAN_PATH]) exclude_paths: List[str] = field( default_factory=lambda: ["/proc", "/sys", "/dev", "/run", "/snap"] ) quarantine_path: str = DEFAULT_QUARANTINE_PATH db_path: str = DEFAULT_DB_PATH log_path: str = DEFAULT_LOG_PATH auto_quarantine: bool = False scan_schedule: str = "0 2 * * *" api_keys: Dict[str, str] = field(default_factory=dict) max_file_size: int = MAX_FILE_SIZE enable_yara: bool = True enable_heuristics: bool = True enable_realtime_monitor: bool = False dashboard_host: str = DEFAULT_DASHBOARD_HOST dashboard_port: int = DEFAULT_DASHBOARD_PORT dashboard_db_path: str = DEFAULT_DASHBOARD_DB_PATH dashboard_username: str = DEFAULT_DASHBOARD_USERNAME dashboard_password: str = DEFAULT_DASHBOARD_PASSWORD @classmethod def load(cls, config_path: Optional[str] = None) -> Config: """Load configuration from a YAML file, then overlay environment variables. Search order: 1. Explicit ``config_path`` argument. 2. /etc/ayn-antivirus/config.yaml 3. ~/.ayn-antivirus/config.yaml 4. Environment variables (always applied last as overrides). """ data: Dict[str, Any] = {} paths_to_try = [config_path] if config_path else DEFAULT_CONFIG_PATHS for path in paths_to_try: if path and Path(path).is_file(): with open(path, "r") as fh: data = yaml.safe_load(fh) or {} break defaults = cls() config = cls( scan_paths=data.get("scan_paths", defaults.scan_paths), exclude_paths=data.get("exclude_paths", defaults.exclude_paths), quarantine_path=data.get("quarantine_path", DEFAULT_QUARANTINE_PATH), db_path=data.get("db_path", DEFAULT_DB_PATH), log_path=data.get("log_path", DEFAULT_LOG_PATH), auto_quarantine=data.get("auto_quarantine", False), scan_schedule=data.get("scan_schedule", "0 2 * * *"), api_keys=data.get("api_keys", {}), max_file_size=data.get("max_file_size", MAX_FILE_SIZE), enable_yara=data.get("enable_yara", True), enable_heuristics=data.get("enable_heuristics", True), enable_realtime_monitor=data.get("enable_realtime_monitor", False), dashboard_host=data.get("dashboard_host", DEFAULT_DASHBOARD_HOST), dashboard_port=data.get("dashboard_port", DEFAULT_DASHBOARD_PORT), dashboard_db_path=data.get("dashboard_db_path", DEFAULT_DASHBOARD_DB_PATH), dashboard_username=data.get("dashboard_username", DEFAULT_DASHBOARD_USERNAME), dashboard_password=data.get("dashboard_password", DEFAULT_DASHBOARD_PASSWORD), ) # --- Environment variable overrides --- config._apply_env_overrides() return config def _apply_env_overrides(self) -> None: """Override config fields with AYN_* environment variables when set.""" if os.getenv("AYN_SCAN_PATH"): self.scan_paths = [p.strip() for p in os.environ["AYN_SCAN_PATH"].split(",")] if os.getenv("AYN_QUARANTINE_PATH"): self.quarantine_path = os.environ["AYN_QUARANTINE_PATH"] if os.getenv("AYN_DB_PATH"): self.db_path = os.environ["AYN_DB_PATH"] if os.getenv("AYN_LOG_PATH"): self.log_path = os.environ["AYN_LOG_PATH"] if os.getenv("AYN_AUTO_QUARANTINE"): self.auto_quarantine = os.environ["AYN_AUTO_QUARANTINE"].lower() in ( "true", "1", "yes", ) if os.getenv("AYN_SCAN_SCHEDULE"): self.scan_schedule = os.environ["AYN_SCAN_SCHEDULE"] if os.getenv("AYN_MALWAREBAZAAR_API_KEY"): self.api_keys["malwarebazaar"] = os.environ["AYN_MALWAREBAZAAR_API_KEY"] if os.getenv("AYN_VIRUSTOTAL_API_KEY"): self.api_keys["virustotal"] = os.environ["AYN_VIRUSTOTAL_API_KEY"] if os.getenv("AYN_MAX_FILE_SIZE"): self.max_file_size = int(os.environ["AYN_MAX_FILE_SIZE"]) if os.getenv("AYN_DASHBOARD_HOST"): self.dashboard_host = os.environ["AYN_DASHBOARD_HOST"] if os.getenv("AYN_DASHBOARD_PORT"): self.dashboard_port = int(os.environ["AYN_DASHBOARD_PORT"]) if os.getenv("AYN_DASHBOARD_DB_PATH"): self.dashboard_db_path = os.environ["AYN_DASHBOARD_DB_PATH"] if os.getenv("AYN_DASHBOARD_USERNAME"): self.dashboard_username = os.environ["AYN_DASHBOARD_USERNAME"] if os.getenv("AYN_DASHBOARD_PASSWORD"): self.dashboard_password = os.environ["AYN_DASHBOARD_PASSWORD"]