remove infra.md.example, infra.md is the source of truth
This commit is contained in:
320
ayn-antivirus/ayn_antivirus/signatures/manager.py
Normal file
320
ayn-antivirus/ayn_antivirus/signatures/manager.py
Normal file
@@ -0,0 +1,320 @@
|
||||
"""Signature manager for AYN Antivirus.
|
||||
|
||||
Orchestrates all threat-intelligence feeds, routes fetched entries into the
|
||||
correct database (hash DB or IOC DB), and exposes high-level update /
|
||||
status / integrity operations for the CLI and scheduler.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import sqlite3
|
||||
import threading
|
||||
import time
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from ayn_antivirus.config import Config
|
||||
from ayn_antivirus.constants import DEFAULT_DB_PATH
|
||||
from ayn_antivirus.core.event_bus import EventType, event_bus
|
||||
from ayn_antivirus.signatures.db.hash_db import HashDatabase
|
||||
from ayn_antivirus.signatures.db.ioc_db import IOCDatabase
|
||||
from ayn_antivirus.signatures.feeds.base_feed import BaseFeed
|
||||
from ayn_antivirus.signatures.feeds.emergingthreats import EmergingThreatsFeed
|
||||
from ayn_antivirus.signatures.feeds.feodotracker import FeodoTrackerFeed
|
||||
from ayn_antivirus.signatures.feeds.malwarebazaar import MalwareBazaarFeed
|
||||
from ayn_antivirus.signatures.feeds.threatfox import ThreatFoxFeed
|
||||
from ayn_antivirus.signatures.feeds.urlhaus import URLHausFeed
|
||||
from ayn_antivirus.signatures.feeds.virusshare import VirusShareFeed
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SignatureManager:
    """Central coordinator for signature / IOC updates.

    Owns a hash database and an IOC database (both opened on the same
    path), a registry of threat-intelligence feeds keyed by name, and an
    optional background thread that periodically refreshes every feed.

    Parameters
    ----------
    config:
        Application configuration.
    db_path:
        Override the database path from config.
    """

    # Fallback ``type`` column for each IOC kind when a feed entry omits it.
    _IOC_DEFAULT_TYPES: Dict[str, str] = {
        "ip": "C2",
        "domain": "C2",
        "url": "malware_distribution",
    }

    def __init__(
        self,
        config: Config,
        db_path: Optional[str | Path] = None,
    ) -> None:
        self.config = config
        self._db_path = Path(db_path or config.db_path)

        # Databases.
        self.hash_db = HashDatabase(self._db_path)
        self.ioc_db = IOCDatabase(self._db_path)

        # Feed registry, keyed by feed name.  All built-in feeds are
        # constructed right here (not lazily); a missing API key is simply
        # forwarded to the feed as whatever ``api_keys.get`` returns.
        self._feeds: Dict[str, BaseFeed] = {}
        self._init_feeds()

        # Auto-update thread handle and its stop signal.
        self._auto_update_stop = threading.Event()
        self._auto_update_thread: Optional[threading.Thread] = None

    # ------------------------------------------------------------------
    # Feed registry
    # ------------------------------------------------------------------

    def _init_feeds(self) -> None:
        """Register the built-in feeds."""
        api_keys = self.config.api_keys

        self._feeds["malwarebazaar"] = MalwareBazaarFeed(
            api_key=api_keys.get("malwarebazaar"),
        )
        self._feeds["threatfox"] = ThreatFoxFeed()
        self._feeds["urlhaus"] = URLHausFeed()
        self._feeds["feodotracker"] = FeodoTrackerFeed()
        self._feeds["emergingthreats"] = EmergingThreatsFeed()
        self._feeds["virusshare"] = VirusShareFeed()

    @property
    def feed_names(self) -> List[str]:
        """Names of all registered feeds, in registration order."""
        return list(self._feeds)

    # ------------------------------------------------------------------
    # Update operations
    # ------------------------------------------------------------------

    def update_all(self) -> Dict[str, Any]:
        """Fetch from every registered feed and store results.

        A failure in one feed is logged and recorded, and does not stop the
        remaining feeds from being processed.

        Returns a summary dict with per-feed statistics (``feeds``), the
        total number of inserted entries (``total_new``) and the names of
        the feeds that failed (``errors``).
        """
        self.hash_db.initialize()
        self.ioc_db.initialize()

        summary: Dict[str, Any] = {"feeds": {}, "total_new": 0, "errors": []}

        for name, feed in self._feeds.items():
            try:
                stats = self._update_single(name, feed)
            except Exception as exc:
                # Deliberately broad: one misbehaving feed must not abort
                # the whole update run.
                logger.exception("Feed '%s' failed", name)
                summary["feeds"][name] = {"error": str(exc)}
                summary["errors"].append(name)
            else:
                summary["feeds"][name] = stats
                summary["total_new"] += stats.get("inserted", 0)

        event_bus.publish(EventType.SIGNATURE_UPDATED, {
            "source": "manager",
            "feeds_updated": len(summary["feeds"]) - len(summary["errors"]),
            "total_new": summary["total_new"],
        })

        logger.info(
            "Signature update complete: %d feed(s), %d new entries, %d error(s)",
            len(self._feeds),
            summary["total_new"],
            len(summary["errors"]),
        )
        return summary

    def update_feed(self, feed_name: str) -> Dict[str, Any]:
        """Update a single feed by name.

        Returns the statistics dict produced for that feed.

        Raises ``KeyError`` if *feed_name* is not registered.
        """
        if feed_name not in self._feeds:
            raise KeyError(f"Unknown feed: {feed_name!r} (available: {self.feed_names})")

        self.hash_db.initialize()
        self.ioc_db.initialize()

        feed = self._feeds[feed_name]
        stats = self._update_single(feed_name, feed)

        event_bus.publish(EventType.SIGNATURE_UPDATED, {
            "source": "manager",
            "feed": feed_name,
            "inserted": stats.get("inserted", 0),
        })

        return stats

    def _update_single(self, name: str, feed: BaseFeed) -> Dict[str, Any]:
        """Fetch from one feed and route entries to the right DB.

        Entries without an ``ioc_type`` key are treated as file hashes;
        entries typed ``ip`` / ``domain`` / ``url`` go to the IOC database.
        Entries with an empty hash or value, or with an unrecognised
        ``ioc_type``, are not stored and are counted under ``skipped``.

        Returns a dict with the keys ``feed``, ``fetched``, ``inserted``,
        ``hashes``, ``ips``, ``domains``, ``urls`` and ``skipped``.
        """
        # Local import: the file-level imports only bring in ``datetime``.
        from datetime import timezone

        logger.info("Updating feed: %s", name)
        # Materialise so ``len()`` works and a ``None`` result is tolerated.
        entries = list(feed.fetch() or [])

        # Classify and batch entries.
        hash_rows: List[tuple] = []
        ioc_rows: Dict[str, List[tuple]] = {
            kind: [] for kind in self._IOC_DEFAULT_TYPES
        }
        skipped = 0

        for entry in entries:
            ioc_type = entry.get("ioc_type")

            if ioc_type is None:
                # Hash-based entry (from MalwareBazaar).
                digest = entry.get("hash", "")
                if not digest:
                    skipped += 1
                    continue
                hash_rows.append((
                    digest,
                    entry.get("threat_name", ""),
                    entry.get("threat_type", "MALWARE"),
                    entry.get("severity", "HIGH"),
                    entry.get("source", name),
                    entry.get("details", ""),
                ))
            elif ioc_type in ioc_rows:
                value = entry.get("value", "")
                if not value:
                    skipped += 1
                    continue
                ioc_rows[ioc_type].append((
                    value,
                    entry.get("threat_name", ""),
                    entry.get("type", self._IOC_DEFAULT_TYPES[ioc_type]),
                    entry.get("source", name),
                ))
            else:
                # Unknown IOC kind: nothing to route it to.
                skipped += 1

        # Only touch a table when there is something to insert.
        hashes_added = self.hash_db.bulk_add(hash_rows) if hash_rows else 0
        ips_added = (
            self.ioc_db.bulk_add_ips(ioc_rows["ip"]) if ioc_rows["ip"] else 0
        )
        domains_added = (
            self.ioc_db.bulk_add_domains(ioc_rows["domain"])
            if ioc_rows["domain"] else 0
        )
        urls_added = (
            self.ioc_db.bulk_add_urls(ioc_rows["url"]) if ioc_rows["url"] else 0
        )

        total = hashes_added + ips_added + domains_added + urls_added

        # Persist last-update timestamp.  ``datetime.utcnow()`` is
        # deprecated; take an aware "now" and drop the tzinfo so the stored
        # string keeps the same naive-UTC ISO format as before.
        stamp = datetime.now(timezone.utc).replace(tzinfo=None).isoformat()
        self.hash_db.set_meta(f"feed_{name}_updated", stamp)

        logger.info(
            "Feed '%s': %d hashes, %d IPs, %d domains, %d URLs",
            name, hashes_added, ips_added, domains_added, urls_added,
        )
        if skipped:
            logger.debug("Feed '%s': skipped %d unusable entries", name, skipped)

        return {
            "feed": name,
            "fetched": len(entries),
            "inserted": total,
            "hashes": hashes_added,
            "ips": ips_added,
            "domains": domains_added,
            "urls": urls_added,
            "skipped": skipped,
        }

    # ------------------------------------------------------------------
    # Status
    # ------------------------------------------------------------------

    def get_status(self) -> Dict[str, Any]:
        """Return per-feed last-update times and aggregate stats."""
        self.hash_db.initialize()
        self.ioc_db.initialize()

        feed_status: Dict[str, Any] = {
            name: {"last_updated": self.hash_db.get_meta(f"feed_{name}_updated")}
            for name in self._feeds
        }

        return {
            "db_path": str(self._db_path),
            "hash_count": self.hash_db.count(),
            "hash_stats": self.hash_db.get_stats(),
            "ioc_stats": self.ioc_db.get_stats(),
            "feeds": feed_status,
        }

    # ------------------------------------------------------------------
    # Auto-update
    # ------------------------------------------------------------------

    def auto_update(self, interval_hours: int = 6) -> None:
        """Start a background thread that periodically calls :meth:`update_all`.

        The first update runs immediately; subsequent ones run every
        *interval_hours*.  Does nothing (apart from a warning) if the
        thread is already running.

        Call :meth:`stop_auto_update` to stop the thread.

        Raises ``ValueError`` if *interval_hours* is not positive, since a
        zero or negative wait would turn the loop into a busy loop that
        hits every feed continuously.
        """
        if interval_hours <= 0:
            raise ValueError(
                f"interval_hours must be positive, got {interval_hours!r}"
            )

        if self._auto_update_thread and self._auto_update_thread.is_alive():
            logger.warning("Auto-update thread is already running")
            return

        self._auto_update_stop.clear()

        def _loop() -> None:
            logger.info("Auto-update started (every %d hours)", interval_hours)
            while not self._auto_update_stop.is_set():
                try:
                    self.update_all()
                except Exception:
                    # Keep the thread alive across a failed cycle.
                    logger.exception("Auto-update cycle failed")
                # Sleeps for the interval but wakes up at once on stop.
                self._auto_update_stop.wait(timeout=interval_hours * 3600)
            logger.info("Auto-update stopped")

        self._auto_update_thread = threading.Thread(
            target=_loop, name="ayn-auto-update", daemon=True
        )
        self._auto_update_thread.start()

    def stop_auto_update(self) -> None:
        """Signal the auto-update thread to stop.

        Waits up to five seconds for the thread to finish.  An update cycle
        that is still in progress after that keeps running until it
        completes and then exits.
        """
        self._auto_update_stop.set()
        worker = self._auto_update_thread
        # A thread cannot join itself (``RuntimeError``); this can happen if
        # the stop is requested from code running on the worker thread.
        if worker is not None and worker is not threading.current_thread():
            worker.join(timeout=5)

    # ------------------------------------------------------------------
    # Integrity
    # ------------------------------------------------------------------

    def verify_db_integrity(self) -> Dict[str, Any]:
        """Run ``PRAGMA integrity_check`` on the database.

        Returns a dict with ``ok`` (bool) and ``details`` (str).  When the
        check reports several problems, ``details`` contains all of them
        joined with ``"; "``.
        """
        self.hash_db.initialize()

        try:
            # The pragma yields a single "ok" row, or one row per problem.
            rows = self.hash_db.conn.execute("PRAGMA integrity_check").fetchall()
        except sqlite3.DatabaseError as exc:
            ok = False
            detail = str(exc)
        else:
            messages = [str(row[0]) for row in rows]
            ok = messages == ["ok"]
            detail = "; ".join(messages) if messages else "no result"

        status = {"ok": ok, "details": detail}
        if not ok:
            logger.error("Database integrity check FAILED: %s", detail)
        else:
            logger.info("Database integrity check passed")
        return status

    # ------------------------------------------------------------------
    # Cleanup
    # ------------------------------------------------------------------

    def close(self) -> None:
        """Stop background threads and close databases.

        Each step is attempted even if an earlier one raises, so a failure
        while stopping the thread or closing one database does not leak the
        other database handle.  The first exception still propagates.
        """
        try:
            self.stop_auto_update()
        finally:
            try:
                self.hash_db.close()
            finally:
                self.ioc_db.close()
|
||||
Reference in New Issue
Block a user