remove infra.md.example, infra.md is the source of truth

This commit is contained in:
Azreen Jamal
2026-03-03 03:06:13 +08:00
parent 1ad3033cc1
commit a3c6d09350
86 changed files with 17093 additions and 39 deletions

View File

@@ -0,0 +1,917 @@
"""Core scan engine for AYN Antivirus.
Orchestrates file-system, process, and network scanning by delegating to
pluggable detectors (hash lookup, YARA, heuristic) and emitting events via
the :pymod:`event_bus`.
"""
from __future__ import annotations
import logging
import os
import time
import uuid
from concurrent.futures import ThreadPoolExecutor, as_completed
from dataclasses import dataclass, field
from datetime import datetime
from enum import Enum, auto
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Protocol
from ayn_antivirus.config import Config
from ayn_antivirus.core.event_bus import EventType, event_bus
from ayn_antivirus.utils.helpers import hash_file as _hash_file_util
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Enums
# ---------------------------------------------------------------------------
class ThreatType(Enum):
"""Classification of a detected threat."""
VIRUS = auto()
MALWARE = auto()
SPYWARE = auto()
MINER = auto()
ROOTKIT = auto()
class Severity(Enum):
"""Threat severity level, ordered low → critical."""
LOW = 1
MEDIUM = 2
HIGH = 3
CRITICAL = 4
class ScanType(Enum):
"""Kind of scan that was executed."""
FULL = "full"
QUICK = "quick"
DEEP = "deep"
SINGLE_FILE = "single_file"
TARGETED = "targeted"
# ---------------------------------------------------------------------------
# Data classes
# ---------------------------------------------------------------------------
@dataclass
class ThreatInfo:
"""A single threat detected during a file scan."""
path: str
threat_name: str
threat_type: ThreatType
severity: Severity
detector_name: str
details: str = ""
timestamp: datetime = field(default_factory=datetime.utcnow)
file_hash: str = ""
@dataclass
class FileScanResult:
"""Result of scanning a single file."""
path: str
scanned: bool = True
file_hash: str = ""
size: int = 0
threats: List[ThreatInfo] = field(default_factory=list)
error: Optional[str] = None
@property
def is_clean(self) -> bool:
return len(self.threats) == 0 and self.error is None
@dataclass
class ProcessThreat:
"""A suspicious process discovered at runtime."""
pid: int
name: str
cmdline: str
cpu_percent: float
memory_percent: float
threat_type: ThreatType
severity: Severity
details: str = ""
@dataclass
class NetworkThreat:
"""A suspicious network connection."""
local_addr: str
remote_addr: str
pid: Optional[int]
process_name: str
threat_type: ThreatType
severity: Severity
details: str = ""
@dataclass
class ScanResult:
"""Aggregated result of a path / multi-file scan."""
scan_id: str = field(default_factory=lambda: uuid.uuid4().hex[:12])
start_time: datetime = field(default_factory=datetime.utcnow)
end_time: Optional[datetime] = None
files_scanned: int = 0
files_skipped: int = 0
threats: List[ThreatInfo] = field(default_factory=list)
scan_path: str = ""
scan_type: ScanType = ScanType.FULL
@property
def duration_seconds(self) -> float:
if self.end_time is None:
return 0.0
return (self.end_time - self.start_time).total_seconds()
@property
def is_clean(self) -> bool:
return len(self.threats) == 0
@dataclass
class ProcessScanResult:
"""Aggregated result of a process scan."""
processes_scanned: int = 0
threats: List[ProcessThreat] = field(default_factory=list)
scan_duration: float = 0.0
@property
def total_processes(self) -> int:
"""Alias for processes_scanned (backward compat)."""
return self.processes_scanned
@property
def is_clean(self) -> bool:
return len(self.threats) == 0
@dataclass
class NetworkScanResult:
"""Aggregated result of a network scan."""
connections_scanned: int = 0
threats: List[NetworkThreat] = field(default_factory=list)
scan_duration: float = 0.0
@property
def total_connections(self) -> int:
"""Alias for connections_scanned (backward compat)."""
return self.connections_scanned
@property
def is_clean(self) -> bool:
return len(self.threats) == 0
@dataclass
class FullScanResult:
"""Combined results from a full scan (files + processes + network + containers)."""
file_scan: ScanResult = field(default_factory=ScanResult)
process_scan: ProcessScanResult = field(default_factory=ProcessScanResult)
network_scan: NetworkScanResult = field(default_factory=NetworkScanResult)
container_scan: Any = None # Optional[ContainerScanResult]
@property
def total_threats(self) -> int:
count = (
len(self.file_scan.threats)
+ len(self.process_scan.threats)
+ len(self.network_scan.threats)
)
if self.container_scan is not None:
count += len(self.container_scan.threats)
return count
@property
def is_clean(self) -> bool:
return self.total_threats == 0
# ---------------------------------------------------------------------------
# Detector protocol (for type hints & documentation)
# ---------------------------------------------------------------------------
class _Detector(Protocol):
"""Any object with a ``detect()`` method matching the BaseDetector API."""
def detect(
self,
file_path: str | Path,
file_content: Optional[bytes] = None,
file_hash: Optional[str] = None,
) -> list: ...
# ---------------------------------------------------------------------------
# Helper: file hashing
# ---------------------------------------------------------------------------
def _hash_file(filepath: Path, algo: str = "sha256") -> str:
"""Return the hex digest of *filepath*.
Delegates to :func:`ayn_antivirus.utils.helpers.hash_file`.
"""
return _hash_file_util(filepath, algo)
# ---------------------------------------------------------------------------
# Detector result → engine dataclass mapping
# ---------------------------------------------------------------------------
_THREAT_TYPE_MAP = {
"VIRUS": ThreatType.VIRUS,
"MALWARE": ThreatType.MALWARE,
"SPYWARE": ThreatType.SPYWARE,
"MINER": ThreatType.MINER,
"ROOTKIT": ThreatType.ROOTKIT,
"HEURISTIC": ThreatType.MALWARE,
}
_SEVERITY_MAP = {
"CRITICAL": Severity.CRITICAL,
"HIGH": Severity.HIGH,
"MEDIUM": Severity.MEDIUM,
"LOW": Severity.LOW,
}
def _map_threat_type(raw: str) -> ThreatType:
"""Convert a detector's threat-type string to :class:`ThreatType`."""
return _THREAT_TYPE_MAP.get(raw.upper(), ThreatType.MALWARE)
def _map_severity(raw: str) -> Severity:
"""Convert a detector's severity string to :class:`Severity`."""
return _SEVERITY_MAP.get(raw.upper(), Severity.MEDIUM)
# ---------------------------------------------------------------------------
# Quick-scan target directories
# ---------------------------------------------------------------------------
QUICK_SCAN_PATHS = [
"/tmp",
"/var/tmp",
"/dev/shm",
"/usr/local/bin",
"/var/spool/cron",
"/etc/cron.d",
"/etc/cron.daily",
"/etc/crontab",
"/var/www",
"/srv",
]
# ---------------------------------------------------------------------------
# ScanEngine
# ---------------------------------------------------------------------------
class ScanEngine:
"""Central orchestrator for all AYN scanning activities.
The engine walks the file system, delegates to pluggable detectors, tracks
statistics, and publishes events on the global :pydata:`event_bus`.
Parameters
----------
config:
Application configuration instance.
max_workers:
Thread pool size for parallel file scanning. Defaults to
``min(os.cpu_count(), 8)``.
"""
def __init__(self, config: Config, max_workers: int | None = None) -> None:
self.config = config
self.max_workers = max_workers or min(os.cpu_count() or 4, 8)
# Detector registry — populated by external plug-ins via register_detector().
# Each detector is a callable: (filepath: Path, cfg: Config) -> List[ThreatInfo]
self._detectors: List[_Detector] = []
self._init_builtin_detectors()
# ------------------------------------------------------------------
# Detector registration
# ------------------------------------------------------------------
def register_detector(self, detector: _Detector) -> None:
"""Add a detector to the scanning pipeline."""
self._detectors.append(detector)
def _init_builtin_detectors(self) -> None:
"""Register all built-in detection engines."""
from ayn_antivirus.detectors.signature_detector import SignatureDetector
from ayn_antivirus.detectors.heuristic_detector import HeuristicDetector
from ayn_antivirus.detectors.cryptominer_detector import CryptominerDetector
from ayn_antivirus.detectors.spyware_detector import SpywareDetector
from ayn_antivirus.detectors.rootkit_detector import RootkitDetector
try:
sig_det = SignatureDetector(db_path=self.config.db_path)
self.register_detector(sig_det)
except Exception as e:
logger.warning("Failed to load SignatureDetector: %s", e)
try:
self.register_detector(HeuristicDetector())
except Exception as e:
logger.warning("Failed to load HeuristicDetector: %s", e)
try:
self.register_detector(CryptominerDetector())
except Exception as e:
logger.warning("Failed to load CryptominerDetector: %s", e)
try:
self.register_detector(SpywareDetector())
except Exception as e:
logger.warning("Failed to load SpywareDetector: %s", e)
try:
self.register_detector(RootkitDetector())
except Exception as e:
logger.warning("Failed to load RootkitDetector: %s", e)
if self.config.enable_yara:
try:
from ayn_antivirus.detectors.yara_detector import YaraDetector
yara_det = YaraDetector()
self.register_detector(yara_det)
except Exception as e:
logger.debug("YARA detector not available: %s", e)
logger.info("Registered %d detectors", len(self._detectors))
# ------------------------------------------------------------------
# File scanning
# ------------------------------------------------------------------
def scan_file(self, filepath: str | Path) -> FileScanResult:
"""Scan a single file through every registered detector.
Parameters
----------
filepath:
Absolute or relative path to the file.
Returns
-------
FileScanResult
"""
filepath = Path(filepath)
result = FileScanResult(path=str(filepath))
if not filepath.is_file():
result.scanned = False
result.error = "Not a file or does not exist"
return result
try:
stat = filepath.stat()
except OSError as exc:
result.scanned = False
result.error = str(exc)
return result
result.size = stat.st_size
if result.size > self.config.max_file_size:
result.scanned = False
result.error = f"File exceeds max size ({result.size} > {self.config.max_file_size})"
return result
# Hash the file — needed by hash-based detectors and for recording.
try:
result.file_hash = _hash_file(filepath)
except OSError as exc:
result.scanned = False
result.error = f"Cannot read file: {exc}"
return result
# Enrich with FileScanner metadata (type classification).
try:
from ayn_antivirus.scanners.file_scanner import FileScanner
file_scanner = FileScanner(max_file_size=self.config.max_file_size)
file_info = file_scanner.scan(str(filepath))
result._file_info = file_info # type: ignore[attr-defined]
except Exception:
logger.debug("FileScanner enrichment skipped for %s", filepath)
# Run every registered detector.
for detector in self._detectors:
try:
detections = detector.detect(filepath, file_hash=result.file_hash)
for d in detections:
threat = ThreatInfo(
path=str(filepath),
threat_name=d.threat_name,
threat_type=_map_threat_type(d.threat_type),
severity=_map_severity(d.severity),
detector_name=d.detector_name,
details=d.details,
file_hash=result.file_hash,
)
result.threats.append(threat)
except Exception:
logger.exception("Detector %r failed on %s", detector, filepath)
# Publish per-file events.
event_bus.publish(EventType.FILE_SCANNED, result)
if result.threats:
for threat in result.threats:
event_bus.publish(EventType.THREAT_FOUND, threat)
return result
# ------------------------------------------------------------------
# Path scanning (recursive)
# ------------------------------------------------------------------
def scan_path(
self,
path: str | Path,
recursive: bool = True,
quick: bool = False,
callback: Optional[Callable[[FileScanResult], None]] = None,
) -> ScanResult:
"""Walk *path* and scan every eligible file.
Parameters
----------
path:
Root directory (or single file) to scan.
recursive:
Descend into subdirectories.
quick:
If ``True``, only scan :pydata:`QUICK_SCAN_PATHS` that exist
under *path* (or the quick-scan list itself when *path* is ``/``).
callback:
Optional function called after each file is scanned — useful for
progress reporting.
Returns
-------
ScanResult
"""
scan_type = ScanType.QUICK if quick else ScanType.FULL
result = ScanResult(
scan_path=str(path),
scan_type=scan_type,
start_time=datetime.utcnow(),
)
event_bus.publish(EventType.SCAN_STARTED, {
"scan_id": result.scan_id,
"scan_type": scan_type.value,
"path": str(path),
})
# Collect files to scan.
files = self._collect_files(Path(path), recursive=recursive, quick=quick)
# Parallel scan.
with ThreadPoolExecutor(max_workers=self.max_workers) as pool:
futures = {pool.submit(self.scan_file, fp): fp for fp in files}
for future in as_completed(futures):
try:
file_result = future.result()
except Exception:
result.files_skipped += 1
logger.exception("Unhandled error scanning %s", futures[future])
continue
if file_result.scanned:
result.files_scanned += 1
else:
result.files_skipped += 1
result.threats.extend(file_result.threats)
if callback is not None:
try:
callback(file_result)
except Exception:
logger.exception("Scan callback raised an exception")
result.end_time = datetime.utcnow()
event_bus.publish(EventType.SCAN_COMPLETED, {
"scan_id": result.scan_id,
"files_scanned": result.files_scanned,
"threats": len(result.threats),
"duration": result.duration_seconds,
})
return result
# ------------------------------------------------------------------
# Process scanning
# ------------------------------------------------------------------
def scan_processes(self) -> ProcessScanResult:
"""Inspect all running processes for known miners and anomalies.
Delegates to :class:`~ayn_antivirus.scanners.process_scanner.ProcessScanner`
for detection and converts results to engine dataclasses.
Returns
-------
ProcessScanResult
"""
from ayn_antivirus.scanners.process_scanner import ProcessScanner
result = ProcessScanResult()
start = time.monotonic()
proc_scanner = ProcessScanner()
scan_data = proc_scanner.scan()
result.processes_scanned = scan_data.get("total", 0)
# Known miner matches.
for s in scan_data.get("suspicious", []):
threat = ProcessThreat(
pid=s["pid"],
name=s.get("name", ""),
cmdline=" ".join(s.get("cmdline") or []),
cpu_percent=s.get("cpu_percent", 0.0),
memory_percent=0.0,
threat_type=ThreatType.MINER,
severity=Severity.CRITICAL,
details=s.get("reason", "Known miner process"),
)
result.threats.append(threat)
event_bus.publish(EventType.THREAT_FOUND, threat)
# High-CPU anomalies (skip duplicates already caught as miners).
miner_pids = {t.pid for t in result.threats}
for h in scan_data.get("high_cpu", []):
if h["pid"] in miner_pids:
continue
threat = ProcessThreat(
pid=h["pid"],
name=h.get("name", ""),
cmdline=" ".join(h.get("cmdline") or []),
cpu_percent=h.get("cpu_percent", 0.0),
memory_percent=0.0,
threat_type=ThreatType.MINER,
severity=Severity.HIGH,
details=h.get("reason", "Abnormally high CPU usage"),
)
result.threats.append(threat)
event_bus.publish(EventType.THREAT_FOUND, threat)
# Hidden processes (possible rootkit).
for hp in scan_data.get("hidden", []):
threat = ProcessThreat(
pid=hp["pid"],
name=hp.get("name", ""),
cmdline=hp.get("cmdline", ""),
cpu_percent=0.0,
memory_percent=0.0,
threat_type=ThreatType.ROOTKIT,
severity=Severity.CRITICAL,
details=hp.get("reason", "Hidden process"),
)
result.threats.append(threat)
event_bus.publish(EventType.THREAT_FOUND, threat)
# Optional memory scan for suspicious PIDs.
try:
from ayn_antivirus.scanners.memory_scanner import MemoryScanner
mem_scanner = MemoryScanner()
suspicious_pids = {t.pid for t in result.threats}
for pid in suspicious_pids:
try:
mem_result = mem_scanner.scan(pid)
rwx_regions = mem_result.get("rwx_regions") or []
if rwx_regions:
result.threats.append(ProcessThreat(
pid=pid,
name="",
cmdline="",
cpu_percent=0.0,
memory_percent=0.0,
threat_type=ThreatType.ROOTKIT,
severity=Severity.HIGH,
details=(
f"Injected code detected in PID {pid}: "
f"{len(rwx_regions)} RWX region(s)"
),
))
except Exception:
pass # Memory scan for individual PID is best-effort
except Exception as exc:
logger.debug("Memory scan skipped: %s", exc)
result.scan_duration = time.monotonic() - start
return result
# ------------------------------------------------------------------
# Network scanning
# ------------------------------------------------------------------
def scan_network(self) -> NetworkScanResult:
"""Scan active network connections for mining pool traffic.
Delegates to :class:`~ayn_antivirus.scanners.network_scanner.NetworkScanner`
for detection and converts results to engine dataclasses.
Returns
-------
NetworkScanResult
"""
from ayn_antivirus.scanners.network_scanner import NetworkScanner
result = NetworkScanResult()
start = time.monotonic()
net_scanner = NetworkScanner()
scan_data = net_scanner.scan()
result.connections_scanned = scan_data.get("total", 0)
# Suspicious connections (mining pools, suspicious ports).
for s in scan_data.get("suspicious", []):
sev = _map_severity(s.get("severity", "HIGH"))
threat = NetworkThreat(
local_addr=s.get("local_addr", "?"),
remote_addr=s.get("remote_addr", "?"),
pid=s.get("pid"),
process_name=(s.get("process", {}) or {}).get("name", ""),
threat_type=ThreatType.MINER,
severity=sev,
details=s.get("reason", "Suspicious connection"),
)
result.threats.append(threat)
event_bus.publish(EventType.THREAT_FOUND, threat)
# Unexpected listening ports.
for lp in scan_data.get("unexpected_listeners", []):
threat = NetworkThreat(
local_addr=lp.get("local_addr", f"?:{lp.get('port', '?')}"),
remote_addr="",
pid=lp.get("pid"),
process_name=lp.get("process_name", ""),
threat_type=ThreatType.MALWARE,
severity=_map_severity(lp.get("severity", "MEDIUM")),
details=lp.get("reason", "Unexpected listener"),
)
result.threats.append(threat)
event_bus.publish(EventType.THREAT_FOUND, threat)
# Enrich with IOC database lookups — flag connections to known-bad IPs.
try:
from ayn_antivirus.signatures.db.ioc_db import IOCDatabase
ioc_db = IOCDatabase(self.config.db_path)
ioc_db.initialize()
malicious_ips = ioc_db.get_all_malicious_ips()
if malicious_ips:
import psutil as _psutil
already_flagged = {
t.remote_addr for t in result.threats
}
try:
for conn in _psutil.net_connections(kind="inet"):
if not conn.raddr:
continue
remote_ip = conn.raddr.ip
remote_str = f"{remote_ip}:{conn.raddr.port}"
if remote_ip in malicious_ips and remote_str not in already_flagged:
ioc_info = ioc_db.lookup_ip(remote_ip) or {}
result.threats.append(NetworkThreat(
local_addr=(
f"{conn.laddr.ip}:{conn.laddr.port}"
if conn.laddr else ""
),
remote_addr=remote_str,
pid=conn.pid or 0,
process_name=self._get_proc_name(conn.pid),
threat_type=ThreatType.MALWARE,
severity=Severity.CRITICAL,
details=(
f"Connection to known malicious IP {remote_ip} "
f"(threat: {ioc_info.get('threat_name', 'IOC match')})"
),
))
except (_psutil.AccessDenied, OSError):
pass
ioc_db.close()
except Exception as exc:
logger.debug("IOC network enrichment skipped: %s", exc)
result.scan_duration = time.monotonic() - start
return result
# ------------------------------------------------------------------
# Helpers
# ------------------------------------------------------------------
@staticmethod
def _get_proc_name(pid: int) -> str:
"""Best-effort process name lookup for a PID."""
if not pid:
return ""
try:
import psutil as _ps
return _ps.Process(pid).name()
except Exception:
return ""
# ------------------------------------------------------------------
# Container scanning
# ------------------------------------------------------------------
def scan_containers(
self,
runtime: str = "all",
container_id: Optional[str] = None,
):
"""Scan containers for threats.
Parameters
----------
runtime:
Container runtime to target (``"all"``, ``"docker"``,
``"podman"``, ``"lxc"``).
container_id:
If provided, scan only this specific container.
Returns
-------
ContainerScanResult
"""
from ayn_antivirus.scanners.container_scanner import ContainerScanner
scanner = ContainerScanner()
if container_id:
return scanner.scan_container(container_id)
return scanner.scan(runtime)
# ------------------------------------------------------------------
# Composite scans
# ------------------------------------------------------------------
def full_scan(
self,
callback: Optional[Callable[[FileScanResult], None]] = None,
) -> FullScanResult:
"""Run a complete scan: files, processes, and network.
Parameters
----------
callback:
Optional per-file progress callback.
Returns
-------
FullScanResult
"""
full = FullScanResult()
# File scan across all configured paths.
aggregate = ScanResult(scan_type=ScanType.FULL, start_time=datetime.utcnow())
for scan_path in self.config.scan_paths:
partial = self.scan_path(scan_path, recursive=True, quick=False, callback=callback)
aggregate.files_scanned += partial.files_scanned
aggregate.files_skipped += partial.files_skipped
aggregate.threats.extend(partial.threats)
aggregate.end_time = datetime.utcnow()
full.file_scan = aggregate
# Process + network.
full.process_scan = self.scan_processes()
full.network_scan = self.scan_network()
# Containers (best-effort — skipped if no runtimes available).
try:
container_result = self.scan_containers()
if container_result.containers_found > 0:
full.container_scan = container_result
except Exception:
logger.debug("Container scanning skipped", exc_info=True)
return full
def quick_scan(
self,
callback: Optional[Callable[[FileScanResult], None]] = None,
) -> ScanResult:
"""Scan only high-risk directories.
Targets :pydata:`QUICK_SCAN_PATHS` and any additional web roots
or crontab locations.
Returns
-------
ScanResult
"""
aggregate = ScanResult(scan_type=ScanType.QUICK, start_time=datetime.utcnow())
event_bus.publish(EventType.SCAN_STARTED, {
"scan_id": aggregate.scan_id,
"scan_type": "quick",
"paths": QUICK_SCAN_PATHS,
})
for scan_path in QUICK_SCAN_PATHS:
p = Path(scan_path)
if not p.exists():
continue
partial = self.scan_path(scan_path, recursive=True, quick=False, callback=callback)
aggregate.files_scanned += partial.files_scanned
aggregate.files_skipped += partial.files_skipped
aggregate.threats.extend(partial.threats)
aggregate.end_time = datetime.utcnow()
event_bus.publish(EventType.SCAN_COMPLETED, {
"scan_id": aggregate.scan_id,
"files_scanned": aggregate.files_scanned,
"threats": len(aggregate.threats),
"duration": aggregate.duration_seconds,
})
return aggregate
# ------------------------------------------------------------------
# Internal helpers
# ------------------------------------------------------------------
def _collect_files(
self,
root: Path,
recursive: bool = True,
quick: bool = False,
) -> List[Path]:
"""Walk *root* and return a list of scannable file paths.
Respects ``config.exclude_paths`` and ``config.max_file_size``.
"""
targets: List[Path] = []
if quick:
# In quick mode, only descend into known-risky subdirectories.
roots = [
root / rel
for rel in (
"tmp", "var/tmp", "dev/shm", "usr/local/bin",
"var/spool/cron", "etc/cron.d", "etc/cron.daily",
"var/www", "srv",
)
if (root / rel).exists()
]
# Also include the quick-scan list itself if root is /.
if str(root) == "/":
roots = [Path(p) for p in QUICK_SCAN_PATHS if Path(p).exists()]
else:
roots = [root]
exclude = set(self.config.exclude_paths)
for r in roots:
if r.is_file():
targets.append(r)
continue
iterator = r.rglob("*") if recursive else r.iterdir()
try:
for entry in iterator:
if not entry.is_file():
continue
# Exclude check.
entry_str = str(entry)
if any(entry_str.startswith(ex) for ex in exclude):
continue
try:
if entry.stat().st_size > self.config.max_file_size:
continue
except OSError:
continue
targets.append(entry)
except PermissionError:
logger.warning("Permission denied: %s", r)
return targets

View File

@@ -0,0 +1,119 @@
"""Simple publish/subscribe event bus for AYN Antivirus.
Decouples the scan engine from consumers like the CLI, logger, quarantine
manager, and real-time monitor so each component can react to events
independently.
"""
from __future__ import annotations
import logging
import threading
from enum import Enum, auto
from typing import Any, Callable, Dict, List
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Event types
# ---------------------------------------------------------------------------
class EventType(Enum):
"""All events emitted by the AYN engine."""
THREAT_FOUND = auto()
SCAN_STARTED = auto()
SCAN_COMPLETED = auto()
FILE_SCANNED = auto()
SIGNATURE_UPDATED = auto()
QUARANTINE_ACTION = auto()
REMEDIATION_ACTION = auto()
DASHBOARD_METRIC = auto()
# Type alias for subscriber callbacks.
Callback = Callable[[EventType, Any], None]
# ---------------------------------------------------------------------------
# EventBus
# ---------------------------------------------------------------------------
class EventBus:
"""Thread-safe publish/subscribe event bus.
Usage::
bus = EventBus()
bus.subscribe(EventType.THREAT_FOUND, lambda et, data: print(data))
bus.publish(EventType.THREAT_FOUND, {"path": "/tmp/evil.elf"})
"""
def __init__(self) -> None:
self._subscribers: Dict[EventType, List[Callback]] = {et: [] for et in EventType}
self._lock = threading.Lock()
# ------------------------------------------------------------------
# Public API
# ------------------------------------------------------------------
def subscribe(self, event_type: EventType, callback: Callback) -> None:
"""Register *callback* to be invoked whenever *event_type* is published.
Parameters
----------
event_type:
The event to listen for.
callback:
A callable with signature ``(event_type, data) -> None``.
"""
with self._lock:
if callback not in self._subscribers[event_type]:
self._subscribers[event_type].append(callback)
def unsubscribe(self, event_type: EventType, callback: Callback) -> None:
"""Remove a previously-registered callback."""
with self._lock:
try:
self._subscribers[event_type].remove(callback)
except ValueError:
pass
def publish(self, event_type: EventType, data: Any = None) -> None:
"""Emit an event, invoking all registered callbacks synchronously.
Exceptions raised by individual callbacks are logged and swallowed so
that one faulty subscriber cannot break the pipeline.
Parameters
----------
event_type:
The event being emitted.
data:
Arbitrary payload — typically a dataclass or dict.
"""
with self._lock:
callbacks = list(self._subscribers[event_type])
for cb in callbacks:
try:
cb(event_type, data)
except Exception:
logger.exception(
"Subscriber %r raised an exception for event %s",
cb,
event_type.name,
)
def clear(self, event_type: EventType | None = None) -> None:
"""Remove all subscribers for *event_type*, or all subscribers if ``None``."""
with self._lock:
if event_type is None:
for et in EventType:
self._subscribers[et].clear()
else:
self._subscribers[event_type].clear()
# ---------------------------------------------------------------------------
# Module-level singleton
# ---------------------------------------------------------------------------
event_bus = EventBus()

View File

@@ -0,0 +1,215 @@
"""Scheduler for recurring scans and signature updates.
Wraps the ``schedule`` library to provide cron-like recurring tasks that
drive the :class:`ScanEngine` and signature updater in a long-running
daemon loop.
"""
from __future__ import annotations
import logging
import time
from typing import Optional
import schedule
from ayn_antivirus.config import Config
from ayn_antivirus.core.engine import ScanEngine, ScanResult
from ayn_antivirus.core.event_bus import EventType, event_bus
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Cron expression helpers
# ---------------------------------------------------------------------------
def _parse_cron_field(field: str, min_val: int, max_val: int) -> list[int]:
"""Parse a single cron field (e.g. ``*/5``, ``1,3,5``, ``0-23``, ``*``).
Returns a sorted list of matching integer values.
"""
values: set[int] = set()
for part in field.split(","):
part = part.strip()
# */step
if part.startswith("*/"):
step = int(part[2:])
values.update(range(min_val, max_val + 1, step))
# range with optional step (e.g. 1-5 or 1-5/2)
elif "-" in part:
range_part, _, step_part = part.partition("/")
lo, hi = range_part.split("-", 1)
step = int(step_part) if step_part else 1
values.update(range(int(lo), int(hi) + 1, step))
# wildcard
elif part == "*":
values.update(range(min_val, max_val + 1))
# literal
else:
values.add(int(part))
return sorted(values)
def _cron_to_schedule(cron_expr: str) -> dict:
"""Convert a 5-field cron expression into components.
Returns a dict with keys ``minutes``, ``hours``, ``days``, ``months``,
``weekdays`` — each a list of integers.
Only *minute* and *hour* are used by the ``schedule`` library adapter
below; the rest are validated but not fully honoured (``schedule`` lacks
calendar-level granularity).
"""
parts = cron_expr.strip().split()
if len(parts) != 5:
raise ValueError(f"Expected 5-field cron expression, got: {cron_expr!r}")
return {
"minutes": _parse_cron_field(parts[0], 0, 59),
"hours": _parse_cron_field(parts[1], 0, 23),
"days": _parse_cron_field(parts[2], 1, 31),
"months": _parse_cron_field(parts[3], 1, 12),
"weekdays": _parse_cron_field(parts[4], 0, 6),
}
# ---------------------------------------------------------------------------
# Scheduler
# ---------------------------------------------------------------------------
class Scheduler:
"""Manages recurring scan and update jobs.
Parameters
----------
config:
Application configuration — used to build a :class:`ScanEngine` and
read schedule expressions.
engine:
Optional pre-built engine instance. If ``None``, one is created from
*config*.
"""
def __init__(self, config: Config, engine: Optional[ScanEngine] = None) -> None:
self.config = config
self.engine = engine or ScanEngine(config)
self._scheduler = schedule.Scheduler()
# ------------------------------------------------------------------
# Job builders
# ------------------------------------------------------------------
def schedule_scan(self, cron_expr: str, scan_type: str = "full") -> None:
"""Schedule a recurring scan using a cron expression.
Parameters
----------
cron_expr:
Standard 5-field cron string (``minute hour dom month dow``).
scan_type:
One of ``"full"``, ``"quick"``, or ``"deep"``.
"""
parsed = _cron_to_schedule(cron_expr)
# ``schedule`` doesn't natively support cron, so we approximate by
# scheduling at every matching hour:minute combination. For simple
# expressions like ``0 2 * * *`` this is exact.
for hour in parsed["hours"]:
for minute in parsed["minutes"]:
time_str = f"{hour:02d}:{minute:02d}"
self._scheduler.every().day.at(time_str).do(
self._run_scan, scan_type=scan_type
)
logger.info("Scheduled %s scan at %s daily", scan_type, time_str)
def schedule_update(self, interval_hours: int = 6) -> None:
"""Schedule recurring signature updates.
Parameters
----------
interval_hours:
How often (in hours) to pull fresh signatures.
"""
self._scheduler.every(interval_hours).hours.do(self._run_update)
logger.info("Scheduled signature update every %d hour(s)", interval_hours)
# ------------------------------------------------------------------
# Daemon loop
# ------------------------------------------------------------------
def run_daemon(self) -> None:
"""Start the blocking scheduler loop.
Runs all pending jobs and sleeps between iterations. Designed to be
the main loop of a background daemon process.
Press ``Ctrl+C`` (or send ``SIGINT``) to exit cleanly.
"""
logger.info("AYN scheduler daemon started — %d job(s)", len(self._scheduler.get_jobs()))
try:
while True:
self._scheduler.run_pending()
time.sleep(30)
except KeyboardInterrupt:
logger.info("Scheduler daemon stopped by user")
# ------------------------------------------------------------------
# Job implementations
# ------------------------------------------------------------------
def _run_scan(self, scan_type: str = "full") -> None:
"""Execute a scan job."""
logger.info("Starting scheduled %s scan", scan_type)
try:
if scan_type == "quick":
result: ScanResult = self.engine.quick_scan()
else:
# "full" and "deep" both scan all paths; deep adds process/network
# via full_scan on the engine, but here we keep it simple.
result = ScanResult()
for path in self.config.scan_paths:
partial = self.engine.scan_path(path, recursive=True)
result.files_scanned += partial.files_scanned
result.files_skipped += partial.files_skipped
result.threats.extend(partial.threats)
logger.info(
"Scheduled %s scan complete — %d files, %d threats",
scan_type,
result.files_scanned,
len(result.threats),
)
except Exception:
logger.exception("Scheduled %s scan failed", scan_type)
def _run_update(self) -> None:
"""Execute a signature update job."""
logger.info("Starting scheduled signature update")
try:
from ayn_antivirus.signatures.manager import SignatureManager
manager = SignatureManager(self.config)
summary = manager.update_all()
total = summary.get("total_new", 0)
errors = summary.get("errors", [])
logger.info(
"Scheduled signature update complete: %d new, %d errors",
total,
len(errors),
)
if errors:
for err in errors:
logger.warning("Feed error: %s", err)
manager.close()
event_bus.publish(EventType.SIGNATURE_UPDATED, {
"total_new": total,
"feeds": list(summary.get("feeds", {}).keys()),
"errors": errors,
})
except Exception:
logger.exception("Scheduled signature update failed")