"""Signature manager for AYN Antivirus.

Orchestrates all threat-intelligence feeds, routes fetched entries into the
correct database (hash DB or IOC DB), and exposes high-level update /
status / integrity operations for the CLI and scheduler.
"""
|
|
|
|
from __future__ import annotations

import logging
import sqlite3
import threading
import time
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, Dict, List, Optional

from ayn_antivirus.config import Config
from ayn_antivirus.constants import DEFAULT_DB_PATH
from ayn_antivirus.core.event_bus import EventType, event_bus
from ayn_antivirus.signatures.db.hash_db import HashDatabase
from ayn_antivirus.signatures.db.ioc_db import IOCDatabase
from ayn_antivirus.signatures.feeds.base_feed import BaseFeed
from ayn_antivirus.signatures.feeds.emergingthreats import EmergingThreatsFeed
from ayn_antivirus.signatures.feeds.feodotracker import FeodoTrackerFeed
from ayn_antivirus.signatures.feeds.malwarebazaar import MalwareBazaarFeed
from ayn_antivirus.signatures.feeds.threatfox import ThreatFoxFeed
from ayn_antivirus.signatures.feeds.urlhaus import URLHausFeed
from ayn_antivirus.signatures.feeds.virusshare import VirusShareFeed
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class SignatureManager:
    """Central coordinator for signature / IOC updates.

    Parameters
    ----------
    config:
        Application configuration.
    db_path:
        Override the database path from config.
    """

    def __init__(
        self,
        config: Config,
        db_path: Optional[str | Path] = None,
    ) -> None:
        self.config = config
        # Explicit override wins; otherwise fall back to the configured path.
        self._db_path = Path(db_path or config.db_path)

        # Both databases share the same SQLite file.
        self.hash_db = HashDatabase(self._db_path)
        self.ioc_db = IOCDatabase(self._db_path)

        # Feeds — instantiated lazily so missing API keys don't crash init.
        self._feeds: Dict[str, BaseFeed] = {}
        self._init_feeds()

        # Auto-update thread handle (created on demand by auto_update()).
        self._auto_update_stop = threading.Event()
        self._auto_update_thread: Optional[threading.Thread] = None
|
|
|
|
# ------------------------------------------------------------------
|
|
# Feed registry
|
|
# ------------------------------------------------------------------
|
|
|
|
def _init_feeds(self) -> None:
    """Register the built-in feeds."""
    api_keys = self.config.api_keys

    # Only MalwareBazaar takes an API key; the rest are anonymous feeds.
    registry: Dict[str, BaseFeed] = {
        "malwarebazaar": MalwareBazaarFeed(api_key=api_keys.get("malwarebazaar")),
        "threatfox": ThreatFoxFeed(),
        "urlhaus": URLHausFeed(),
        "feodotracker": FeodoTrackerFeed(),
        "emergingthreats": EmergingThreatsFeed(),
        "virusshare": VirusShareFeed(),
    }
    self._feeds.update(registry)
|
|
|
|
@property
|
|
def feed_names(self) -> List[str]:
|
|
return list(self._feeds.keys())
|
|
|
|
# ------------------------------------------------------------------
|
|
# Update operations
|
|
# ------------------------------------------------------------------
|
|
|
|
def update_all(self) -> Dict[str, Any]:
    """Fetch from every registered feed and store results.

    Returns a summary dict with per-feed statistics.
    """
    self.hash_db.initialize()
    self.ioc_db.initialize()

    summary: Dict[str, Any] = {"feeds": {}, "total_new": 0, "errors": []}

    for feed_name, feed in self._feeds.items():
        try:
            result = self._update_single(feed_name, feed)
        except Exception as exc:
            # One broken feed must not abort the whole update run.
            logger.exception("Feed '%s' failed", feed_name)
            summary["feeds"][feed_name] = {"error": str(exc)}
            summary["errors"].append(feed_name)
        else:
            summary["feeds"][feed_name] = result
            summary["total_new"] += result.get("inserted", 0)

    event_bus.publish(EventType.SIGNATURE_UPDATED, {
        "source": "manager",
        "feeds_updated": len(summary["feeds"]) - len(summary["errors"]),
        "total_new": summary["total_new"],
    })

    logger.info(
        "Signature update complete: %d feed(s), %d new entries, %d error(s)",
        len(self._feeds),
        summary["total_new"],
        len(summary["errors"]),
    )
    return summary
|
|
|
|
def update_feed(self, feed_name: str) -> Dict[str, Any]:
    """Update a single feed by name.

    Raises ``KeyError`` if *feed_name* is not registered.
    """
    try:
        feed = self._feeds[feed_name]
    except KeyError:
        raise KeyError(
            f"Unknown feed: {feed_name!r} (available: {self.feed_names})"
        ) from None

    self.hash_db.initialize()
    self.ioc_db.initialize()

    stats = self._update_single(feed_name, feed)

    event_bus.publish(EventType.SIGNATURE_UPDATED, {
        "source": "manager",
        "feed": feed_name,
        "inserted": stats.get("inserted", 0),
    })
    return stats
|
|
|
|
def _update_single(self, name: str, feed: BaseFeed) -> Dict[str, Any]:
    """Fetch from one feed and route entries to the right DB.

    Entries without an ``ioc_type`` key are treated as file-hash
    signatures; otherwise ``ioc_type`` selects the IOC table
    (``"ip"``, ``"domain"`` or ``"url"``).  Entries whose primary
    value (``hash`` / ``value``) is empty are skipped instead of
    being stored as blank rows.

    Parameters
    ----------
    name:
        Registry name of the feed; used for logging and as the
        default ``source`` attribution.
    feed:
        The feed instance to fetch from.

    Returns
    -------
    Dict with per-category insert counts plus ``fetched`` and
    ``inserted`` totals.
    """
    logger.info("Updating feed: %s", name)
    entries = feed.fetch()

    # Classify and batch entries before touching the databases.
    hash_rows: List[tuple] = []
    ip_rows: List[tuple] = []
    domain_rows: List[tuple] = []
    url_rows: List[tuple] = []

    for entry in entries:
        ioc_type = entry.get("ioc_type")

        if ioc_type is None:
            # Hash-based entry (e.g. from MalwareBazaar).
            file_hash = entry.get("hash", "")
            if not file_hash:
                continue  # nothing to index on — skip blank rows
            hash_rows.append((
                file_hash,
                entry.get("threat_name", ""),
                entry.get("threat_type", "MALWARE"),
                entry.get("severity", "HIGH"),
                entry.get("source", name),
                entry.get("details", ""),
            ))
            continue

        value = entry.get("value", "")
        if not value:
            continue  # skip empty IOC values

        if ioc_type == "ip":
            ip_rows.append((
                value,
                entry.get("threat_name", ""),
                entry.get("type", "C2"),
                entry.get("source", name),
            ))
        elif ioc_type == "domain":
            domain_rows.append((
                value,
                entry.get("threat_name", ""),
                entry.get("type", "C2"),
                entry.get("source", name),
            ))
        elif ioc_type == "url":
            url_rows.append((
                value,
                entry.get("threat_name", ""),
                entry.get("type", "malware_distribution"),
                entry.get("source", name),
            ))
        # Unknown ioc_type values are silently ignored, as before.

    hashes_added = self.hash_db.bulk_add(hash_rows) if hash_rows else 0
    ips_added = self.ioc_db.bulk_add_ips(ip_rows) if ip_rows else 0
    domains_added = self.ioc_db.bulk_add_domains(domain_rows) if domain_rows else 0
    urls_added = self.ioc_db.bulk_add_urls(url_rows) if url_rows else 0

    total = hashes_added + ips_added + domains_added + urls_added

    # Persist last-update timestamp.  Use an aware UTC datetime:
    # datetime.utcnow() is deprecated (Python 3.12+) and yields a naive
    # timestamp with no offset information.
    self.hash_db.set_meta(
        f"feed_{name}_updated", datetime.now(timezone.utc).isoformat()
    )

    logger.info(
        "Feed '%s': %d hashes, %d IPs, %d domains, %d URLs",
        name, hashes_added, ips_added, domains_added, urls_added,
    )

    return {
        "feed": name,
        "fetched": len(entries),
        "inserted": total,
        "hashes": hashes_added,
        "ips": ips_added,
        "domains": domains_added,
        "urls": urls_added,
    }
|
|
|
|
# ------------------------------------------------------------------
|
|
# Status
|
|
# ------------------------------------------------------------------
|
|
|
|
def get_status(self) -> Dict[str, Any]:
    """Return per-feed last-update times and aggregate stats."""
    self.hash_db.initialize()
    self.ioc_db.initialize()

    # Last-update timestamps are stored in hash-DB metadata per feed.
    feed_status: Dict[str, Any] = {
        name: {"last_updated": self.hash_db.get_meta(f"feed_{name}_updated")}
        for name in self._feeds
    }

    return {
        "db_path": str(self._db_path),
        "hash_count": self.hash_db.count(),
        "hash_stats": self.hash_db.get_stats(),
        "ioc_stats": self.ioc_db.get_stats(),
        "feeds": feed_status,
    }
|
|
|
|
# ------------------------------------------------------------------
|
|
# Auto-update
|
|
# ------------------------------------------------------------------
|
|
|
|
def auto_update(self, interval_hours: int = 6) -> None:
    """Start a background thread that periodically calls :meth:`update_all`.

    Call :meth:`stop_auto_update` to stop the thread.
    """
    existing = self._auto_update_thread
    if existing is not None and existing.is_alive():
        logger.warning("Auto-update thread is already running")
        return

    self._auto_update_stop.clear()
    stop = self._auto_update_stop  # bind once for the worker closure

    def _worker() -> None:
        logger.info("Auto-update started (every %d hours)", interval_hours)
        while not stop.is_set():
            try:
                self.update_all()
            except Exception:
                # Keep the loop alive — a failed cycle just logs and retries.
                logger.exception("Auto-update cycle failed")
            # wait() returns early as soon as stop is set.
            stop.wait(timeout=interval_hours * 3600)
        logger.info("Auto-update stopped")

    self._auto_update_thread = threading.Thread(
        target=_worker, name="ayn-auto-update", daemon=True
    )
    self._auto_update_thread.start()
|
|
|
|
def stop_auto_update(self) -> None:
|
|
"""Signal the auto-update thread to stop."""
|
|
self._auto_update_stop.set()
|
|
if self._auto_update_thread:
|
|
self._auto_update_thread.join(timeout=5)
|
|
|
|
# ------------------------------------------------------------------
|
|
# Integrity
|
|
# ------------------------------------------------------------------
|
|
|
|
def verify_db_integrity(self) -> Dict[str, Any]:
    """Run ``PRAGMA integrity_check`` on the database.

    Returns a dict with ``ok`` (bool) and ``details`` (str).
    """
    self.hash_db.initialize()

    try:
        row = self.hash_db.conn.execute("PRAGMA integrity_check").fetchone()
    except sqlite3.DatabaseError as exc:
        # Corruption can surface as an exception rather than a result row.
        ok, detail = False, str(exc)
    else:
        if row:
            detail = row[0]
            ok = detail == "ok"  # SQLite reports the literal string "ok"
        else:
            ok, detail = False, "no result"

    if ok:
        logger.info("Database integrity check passed")
    else:
        logger.error("Database integrity check FAILED: %s", detail)
    return {"ok": ok, "details": detail}
|
|
|
|
# ------------------------------------------------------------------
|
|
# Cleanup
|
|
# ------------------------------------------------------------------
|
|
|
|
def close(self) -> None:
    """Stop background threads and close databases."""
    self.stop_auto_update()
    # Close hash DB first, then IOC DB — same order as the original.
    for db in (self.hash_db, self.ioc_db):
        db.close()
|