Skip to content

Instantly share code, notes, and snippets.

@gary23w
Last active February 28, 2026 12:11
Show Gist options
  • Select an option

  • Save gary23w/3bbb51566a7ba1a7bbc1a7e2d384b8b6 to your computer and use it in GitHub Desktop.

Select an option

Save gary23w/3bbb51566a7ba1a7bbc1a7e2d384b8b6 to your computer and use it in GitHub Desktop.
simple-packet-sniffer.py
"""
simplepacketsniffer.py 0.4
Usage examples:
Dataset Building Mode (with LLM labeling):
sudo python simplepacketsniffer.py --build-dataset --llm-label --num-samples 10 --dataset-out training_data.csv -i INTERFACE -v INFO [--xdp]
Training Mode:
sudo python simplepacketsniffer.py --train --dataset training_data.csv --model-path model.pkl -v INFO
Sniffer (Protection) Mode:
sudo python simplepacketsniffer.py -i INTERFACE --red-alert --whitelist US,CA --sentry --model-path model.pkl -v DEBUG [--xdp]
Expensive Mode (Real-time LLM analysis for every packet):
sudo python simplepacketsniffer.py --expensive -i INTERFACE -v DEBUG [--xdp]
***logging.yaml
version: 1
disable_existing_loggers: false
formatters:
  standard:
    format: '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
  simple:
    format: '%(asctime)s - %(levelname)s - %(message)s'
handlers:
  console:
    class: logging.StreamHandler
    level: DEBUG
    formatter: standard
    stream: ext://sys.stdout
  file_handler:
    class: logging.handlers.RotatingFileHandler
    level: DEBUG
    formatter: standard
    filename: sniffer.log
    maxBytes: 1048576
    backupCount: 3
  flagged_ip_handler:
    class: logging.handlers.RotatingFileHandler
    level: WARNING
    formatter: standard
    filename: flagged_ips.log
    maxBytes: 1048576
    backupCount: 3
  dns_handler:
    class: logging.handlers.RotatingFileHandler
    level: INFO
    formatter: standard
    filename: dns_queries.log
    maxBytes: 1048576
    backupCount: 3
loggers:
  PacketSniffer:
    level: DEBUG
    handlers: [console, file_handler]
    propagate: no
  FlaggedIPLogger:
    level: WARNING
    handlers: [flagged_ip_handler, console]
    propagate: no
  DNSQueryLogger:
    level: INFO
    handlers: [dns_handler, console]
    propagate: no
root:
  level: DEBUG
  handlers: [console]
***
"""
from __future__ import annotations
import argparse
import asyncio
import logging
import logging.config
import subprocess
import sys
import platform
import re
import ipaddress
import time
import threading
import struct
import os
import csv
import string
import gzip
import zlib
import hashlib
import binascii
import shutil
from dataclasses import dataclass, field
from datetime import datetime
from collections import Counter
from collections import defaultdict, deque
from typing import Tuple, Dict, Any, List, Optional
try:
import aiohttp
except ImportError:
aiohttp = None
try:
import yaml
except ImportError:
yaml = None
try:
from cachetools import TTLCache
except ImportError:
TTLCache = None
try:
import numpy as np
except ImportError:
np = None
try:
import scapy.all as scapy
from scapy.all import sniff, Ether, IP, IPv6, TCP, UDP, ICMP, ARP, Raw
from scapy.layers.dns import DNS, DNSQR
except ImportError:
scapy = None
sniff = None
Ether = None
IP = None
IPv6 = None
TCP = None
UDP = None
ICMP = None
ARP = None
Raw = None
DNS = None
DNSQR = None
try:
from scapy.layers.tls.all import TLS, TLSClientHello, TLSHandshake
except ImportError:
TLS = None
TLSClientHello = None
TLSHandshake = None
try:
from openai import AsyncOpenAI
except ImportError:
AsyncOpenAI = None
try:
from bcc import BPF
except ImportError:
BPF = None
try:
import yara
except ImportError:
yara = None
last_xdp_metadata: Optional[Dict[str, Any]] = None
def setup_logging(config_file: str = "logging.yaml") -> None:
    """Configure the logging system from a YAML dictConfig file.

    The standard streams are first switched to UTF-8 with replacement of
    unencodable characters (best effort). If the YAML file cannot be loaded
    for any reason (PyYAML missing, file missing, invalid config), logging
    falls back to ``basicConfig`` at INFO level and a warning is emitted.

    Args:
        config_file: Path to the YAML logging configuration.
    """
    # Best effort only: a failure here must never prevent start-up.
    try:
        for stream in (sys.stdout, sys.stderr):
            if hasattr(stream, "reconfigure"):
                stream.reconfigure(encoding="utf-8", errors="replace")
    except Exception:
        pass

    try:
        if yaml is None:
            raise ImportError("PyYAML is not installed")
        with open(config_file, "r") as handle:
            raw_config = handle.read()
        logging.config.dictConfig(yaml.safe_load(raw_config))
    except Exception as e:
        logging.basicConfig(
            level=logging.INFO,
            format="%(asctime)s - %(levelname)s - %(message)s",
        )
        fallback_logger = logging.getLogger(__name__)
        fallback_logger.warning(
            "Failed to load logging config from %s, using basic config: %s",
            config_file, e
        )
def check_packet_capture_backend(logger: Optional[logging.Logger] = None) -> Tuple[bool, str]:
    """Report whether a packet-capture backend is usable.

    Args:
        logger: Optional logger; a module logger is resolved when omitted.

    Returns:
        ``(True, "")`` when capture looks possible, otherwise ``(False, reason)``
        with a human-readable explanation.
    """
    logger = logger or logging.getLogger(__name__)

    if scapy is None:
        return False, "Scapy is not installed. Run with --install first."

    # Only Windows needs the extra Npcap/WinPcap probe.
    if os.name != "nt":
        return True, ""

    pcap_enabled = bool(getattr(scapy.conf, "use_pcap", False))
    listener_repr = str(getattr(scapy.conf, "L2listen", ""))
    backend_missing = (
        not pcap_enabled
        or "_NotAvailableSocket" in listener_repr
        or "wpcap.dll missing" in listener_repr.lower()
    )
    if backend_missing:
        return (
            False,
            "Npcap/WinPcap is not installed or not available. Install Npcap from https://npcap.com/ and rerun."
        )
    return True, ""
class AsyncRunner:
    """Own a private asyncio event loop running on a daemon thread.

    Lets synchronous code (e.g. a packet callback) submit coroutines and get
    back a ``concurrent.futures.Future`` for the result.
    """

    def __init__(self) -> None:
        """Create the loop and start the background thread that drives it."""
        self.loop = asyncio.new_event_loop()
        worker = threading.Thread(target=self._run_loop, daemon=True)
        self.thread = worker
        worker.start()

    def _run_loop(self) -> None:
        """Thread target: bind the loop to this thread and run it forever."""
        asyncio.set_event_loop(self.loop)
        self.loop.run_forever()

    def run_coroutine(self, coro):
        """Schedule ``coro`` on the background loop and return its future."""
        future = asyncio.run_coroutine_threadsafe(coro, self.loop)
        return future
class IPLookup:
    """Look up IP metadata from ipinfo.io with an in-memory TTL cache."""

    # Total time budget, in seconds, for a single lookup request.
    REQUEST_TIMEOUT_SECONDS = 5

    def __init__(self, ttl: int = 3600, cache_size: int = 1000, logger: Optional[logging.Logger] = None) -> None:
        """Create the lookup helper.

        Args:
            ttl: Lifetime of a cached entry, in seconds.
            cache_size: Maximum number of cached IPs.
            logger: Optional logger; a module logger is used when omitted.

        Raises:
            ImportError: If ``cachetools`` or ``aiohttp`` is not installed.
        """
        if TTLCache is None or aiohttp is None:
            raise ImportError("IPLookup requires 'cachetools' and 'aiohttp'.")
        self.cache = TTLCache(maxsize=cache_size, ttl=ttl)
        self.logger = logger or logging.getLogger(__name__)

    async def fetch_ip_info(self, session: "aiohttp.ClientSession", ip: str) -> Dict[str, Any]:
        """Return the ipinfo.io record for ``ip`` using ``session``.

        Cached answers are returned without touching the network. Any failure
        (non-200 status, network error, unexpected exception) is logged and
        yields an empty dict; failures are not cached.
        """
        # Single lookup instead of "in" + "[]": a TTL entry may expire between
        # the membership test and the read.
        try:
            return self.cache[ip]
        except KeyError:
            pass

        try:
            # aiohttp expects a ClientTimeout object; passing a bare number is deprecated.
            timeout = aiohttp.ClientTimeout(total=self.REQUEST_TIMEOUT_SECONDS)
            async with session.get(f"https://ipinfo.io/{ip}/json", timeout=timeout) as resp:
                if resp.status == 200:
                    data = await resp.json()
                    self.cache[ip] = data
                    return data
                self.logger.warning("IP lookup for %s returned status %s", ip, resp.status)
        except aiohttp.ClientError as e:
            self.logger.exception("Network error during IP lookup for %s: %s", ip, e)
        except Exception:
            self.logger.exception("Unexpected error during IP lookup for %s", ip)
        return {}

    async def get_ip_info(self, ip: str) -> Dict[str, Any]:
        """Look up ``ip`` using a short-lived client session."""
        async with aiohttp.ClientSession() as session:
            return await self.fetch_ip_info(session, ip)
class DNSParser:
    """Minimal DNS wire-format parser used for logging queries."""

    def __init__(self, logger: Optional[logging.Logger] = None, blacklist: Optional[List[str]] = None) -> None:
        """Store the logger and the set of exact-match blacklisted domains."""
        self.logger = logger or logging.getLogger(__name__)
        self.blacklist = set(blacklist or ['ipinfo.io', 'ipinfo.io.'])  # TODO: add wildcard domains

    def is_blacklisted(self, domain: str) -> bool:
        """Return True when ``domain`` exactly matches a blacklist entry."""
        return domain in self.blacklist

    def parse_dns_name(self, payload: bytes, offset: int) -> Tuple[str, int]:
        """Decode a (possibly compressed) domain name starting at ``offset``.

        Compression pointers are followed iteratively and each pointer target
        is visited at most once, so a malicious packet whose pointer refers to
        itself (or forms a cycle) terminates instead of recursing without bound.

        Args:
            payload: The complete DNS message.
            offset: Index of the first length byte of the name.

        Returns:
            ``(name, next_offset)`` where ``next_offset`` is the index just
            past the name at its original location (i.e. after the first
            pointer, if one was followed). Truncated input yields whatever
            labels were decoded so far.
        """
        labels: List[str] = []
        resume_offset: Optional[int] = None  # set once the first pointer is followed
        seen_targets = set()
        pos = offset

        while pos < len(payload):
            length = payload[pos]
            if (length & 0xC0) == 0xC0:
                if pos + 1 >= len(payload):
                    break  # truncated pointer
                if resume_offset is None:
                    # The name ends, at its original location, right after this pointer.
                    resume_offset = pos + 2
                target = ((length & 0x3F) << 8) | payload[pos + 1]
                if target in seen_targets:
                    break  # pointer loop
                seen_targets.add(target)
                pos = target
                continue
            if length == 0:
                pos += 1
                break
            pos += 1
            labels.append(payload[pos: pos + length].decode('utf-8', errors='replace'))
            pos += length

        end = pos if resume_offset is None else resume_offset
        return ".".join(labels), end

    def parse_dns_payload(self, payload: bytes) -> None:
        """Log the DNS header and every question contained in ``payload``."""
        self.logger.info("Parsing DNS payload...")
        if len(payload) < 12:
            self.logger.warning("DNS payload too short to parse.")
            return
        try:
            transaction_id, flags, qdcount, ancount, nscount, arcount = struct.unpack("!HHHHHH", payload[:12])
            self.logger.info("DNS Header: ID=%#04x, Flags=%#04x, QD=%d, AN=%d, NS=%d, AR=%d",
                             transaction_id, flags, qdcount, ancount, nscount, arcount)
        except Exception as e:
            self.logger.exception("Error parsing DNS header: %s", e)
            return

        offset = 12  # questions start right after the fixed 12-byte header
        for i in range(qdcount):
            try:
                domain, offset = self.parse_dns_name(payload, offset)
                if offset + 4 > len(payload):
                    self.logger.warning("DNS question truncated.")
                    break
                qtype, qclass = struct.unpack("!HH", payload[offset:offset + 4])
                offset += 4
                self.logger.info("DNS Question %d: %s, type: %d, class: %d", i + 1, domain, qtype, qclass)
            except Exception as e:
                self.logger.exception("Error parsing DNS question: %s", e)
                break
class PayloadParser:
    """Log payloads in readable form and detect well-known file signatures."""

    # Signatures whose hex form is at most this long are matched only at the
    # very start of the payload; longer ones may match anywhere.
    _HEAD_ONLY_MAX_HEX_LEN = 8
    _PRINTABLE = frozenset(string.printable)

    def __init__(self, logger: Optional[logging.Logger] = None) -> None:
        """Store the logger and build the per-instance signature table."""
        self.logger = logger or logging.getLogger(__name__)
        # Maps an upper-case hex magic number to a human-readable description.
        self.SIGNATURES = {
            "504B0304": "ZIP archive / Office Open XML document (DOCX, XLSX, PPTX)",
            "504B030414000600": "Office Open XML (DOCX, XLSX, PPTX) extended header",
            "1F8B08": "GZIP archive",
            "377ABCAF271C": "7-Zip archive",
            "52617221": "RAR archive",
            "425A68": "BZIP2 archive",
            "213C617263683E0A": "Ar (UNIX archive) / Debian package",
            "7F454C46": "ELF executable (Unix/Linux)",
            "4D5A": "Windows executable (EXE, MZ header / DLL)",
            "CAFEBABE": "Java class file or Mach-O Fat Binary (ambiguous)",
            "FEEDFACE": "Mach-O executable (32-bit, little-endian)",
            "CEFAEDFE": "Mach-O executable (32-bit, big-endian)",
            "FEEDFACF": "Mach-O executable (64-bit, little-endian)",
            "CFFAEDFE": "Mach-O executable (64-bit, big-endian)",
            "BEBAFECA": "Mach-O Fat Binary (little endian)",
            "4C000000": "Windows shortcut file (.lnk)",
            "4D534346": "Microsoft Cabinet file (CAB)",
            "D0CF11E0": "Microsoft Office legacy format (DOC, XLS, PPT)",
            "25504446": "PDF document",
            "7B5C727466": "RTF document (starting with '{\\rtf')",
            "3C3F786D6C": "XML file (<?xml)",
            "3C68746D6C3E": "HTML file",
            "252150532D41646F6265": "PostScript/EPS document (starts with '%!PS-Adobe')",
            "4D2D2D2D": "PostScript file (---)",
            "89504E47": "PNG image",
            "47494638": "GIF image",
            "FFD8FF": "JPEG image (general)",
            "FFD8FFE0": "JPEG image (JFIF)",
            "FFD8FFE1": "JPEG image (EXIF)",
            "424D": "Bitmap image (BMP)",
            "49492A00": "TIFF image (little endian / Intel)",
            "4D4D002A": "TIFF image (big endian / Motorola)",
            "38425053": "Adobe Photoshop document (PSD)",
            "00000100": "ICO icon file",
            "00000200": "CUR cursor file",
            "494433": "MP3 audio (ID3 tag)",
            "000001BA": "MPEG video (VCD)",
            "000001B3": "MPEG video",
            "66747970": "MP4/MOV file (ftyp)",
            "4D546864": "MIDI file",
            "464F524D": "AIFF audio file",
            "52494646": "AVI file (RIFF) [Also used in WAV]",
            "664C6143": "FLAC audio file",
            "4F676753": "OGG container file (OggS)",
            "53514C69": "SQLite database file (SQLite format 3)",
            "420D0D0A": "Python compiled file (.pyc) [example magic, may vary]",
            "6465780A": "Android Dalvik Executable (DEX) file",
            "EDABEEDB": "RPM package file",
            "786172210D0A1A0A": "XAR archive (macOS installer package)",
        }

    def normalize_payload(self, payload: bytes) -> str:
        """Return a printable text rendering of ``payload``, or its hex form.

        The payload is decoded as UTF-8 (undecodable bytes are dropped) and
        non-printable characters are replaced with ``.``. When more than half
        of the decoded characters are non-printable the hex representation is
        returned instead. Genuine ``.`` characters in the payload do not count
        towards that threshold.
        """
        try:
            text = payload.decode('utf-8', errors='ignore')
            printable = self._PRINTABLE
            unprintable = sum(1 for ch in text if ch not in printable)
            if unprintable > len(text) * 0.5:
                return payload.hex()
            return ''.join(ch if ch in printable else '.' for ch in text)
        except Exception:
            # e.g. a buffer type without .decode(); hex is the safe fallback.
            return payload.hex()

    def parse_http_payload(self, payload: bytes) -> None:
        """Log ``payload`` as HTTP text."""
        try:
            text = payload.decode('utf-8', errors='replace')
            text = self._make_log_safe(text)
            self.logger.info("HTTP Payload:\n%s", text)
        except Exception as e:
            self.logger.exception("Error parsing HTTP payload: %s", e)

    def parse_text_payload(self, payload: bytes) -> None:
        """Log ``payload`` as generic text."""
        try:
            text = payload.decode('utf-8', errors='replace')
            text = self._make_log_safe(text)
            self.logger.info("Text Payload:\n%s", text)
        except Exception as e:
            self.logger.exception("Error decoding text payload: %s", e)

    def parse_tls_payload(self, payload: bytes) -> None:
        """Log ``payload`` as a hex string."""
        self.logger.info("TLS Payload (hex):\n%s", payload.hex())

    def analyze_hex_dump(self, payload: bytes) -> List[Tuple[str, str]]:
        """Return every ``(signature_hex, description)`` found in ``payload``.

        Short signatures must appear at offset 0. Longer signatures may appear
        anywhere. Matching is done on bytes, so a hit is always byte-aligned;
        searching the hex text would also match at odd nibble offsets and
        report signatures that are not really present.
        """
        self.logger.info("Analyzing payload hex dump for signatures...")
        data = bytes(payload)
        found: List[Tuple[str, str]] = []
        for sig, desc in self.SIGNATURES.items():
            magic = bytes.fromhex(sig)
            if len(sig) <= self._HEAD_ONLY_MAX_HEX_LEN:
                matched = data.startswith(magic)
            else:
                matched = magic in data
            if matched:
                found.append((sig, desc))
                self.logger.warning("Detected signature %s: %s", sig, desc)
        return found

    def _make_log_safe(self, text: str) -> str:
        """Round-trip ``text`` through stdout's encoding, replacing what it cannot represent."""
        target_encoding = getattr(sys.stdout, "encoding", None) or "utf-8"
        return text.encode(target_encoding, errors="replace").decode(target_encoding, errors="replace")
@dataclass(frozen=True)
class FlowKey:
    """Immutable, hashable 5-tuple identifying a transport flow."""

    # First endpoint of the flow.
    src_ip: str
    src_port: int
    # Second endpoint of the flow.
    dst_ip: str
    dst_port: int
    # Transport protocol label, e.g. "TCP".
    proto: str
def canonical_flow_key(src_ip: str, src_port: int, dst_ip: str, dst_port: int, proto: str = "TCP") -> Tuple[FlowKey, bool]:
    """Build a direction-independent key for a flow.

    Both directions of a conversation map to the same ``FlowKey``: the
    endpoint that compares lower as an ``(ip, port)`` tuple is stored first.

    Returns:
        ``(key, forward)`` where ``forward`` is True when the given source
        endpoint is the one stored first in the key.
    """
    source = (src_ip, src_port)
    destination = (dst_ip, dst_port)
    if source > destination:
        # Swap so the lower endpoint always comes first.
        return FlowKey(dst_ip, dst_port, src_ip, src_port, proto), False
    return FlowKey(src_ip, src_port, dst_ip, dst_port, proto), True
@dataclass
class StreamDirectionState:
    """Reassembly state for one direction of a TCP stream.

    NOTE: sequence numbers are compared as plain integers; 32-bit sequence
    wraparound is not handled here.
    """

    # Sequence number of the first segment ever seen for this direction.
    base_seq: Optional[int] = None
    # Sequence number expected next (one past the last contiguous byte).
    next_seq: Optional[int] = None
    # In-order bytes reassembled so far.
    contiguous: bytearray = field(default_factory=bytearray)
    # Segments that arrived ahead of ``next_seq``, keyed by their sequence number.
    out_of_order: Dict[int, bytes] = field(default_factory=dict)

    def add(self, seq: int, payload: bytes) -> int:
        """Feed one segment and return how many bytes became contiguous.

        Segments ahead of the expected sequence number are buffered (the
        longest payload wins for a given sequence number). Segments at or
        before it are appended after trimming any part already received, and
        every buffered segment that has become reachable is then drained,
        including ones that partially or fully overlap data already held,
        so they are neither lost nor left behind in the buffer.
        """
        if not payload:
            return 0
        if self.base_seq is None:
            self.base_seq = seq
            self.next_seq = seq
        elif self.next_seq is None:
            self.next_seq = seq

        if seq > self.next_seq:
            held = self.out_of_order.get(seq)
            if held is None or len(payload) > len(held):
                self.out_of_order[seq] = payload
            return 0

        appended = self._append_from(seq, payload)

        # Drain buffered segments that now start at or before next_seq.
        while self.out_of_order:
            ready = sorted(s for s in self.out_of_order if s <= self.next_seq)
            if not ready:
                break
            for buffered_seq in ready:
                appended += self._append_from(buffered_seq, self.out_of_order.pop(buffered_seq))
        return appended

    def _append_from(self, seq: int, payload: bytes) -> int:
        """Append the not-yet-received tail of a segment starting at ``seq`` (<= next_seq)."""
        overlap = self.next_seq - seq
        if overlap >= len(payload):
            return 0  # pure retransmission, nothing new
        fresh = payload[overlap:] if overlap > 0 else payload
        self.contiguous.extend(fresh)
        self.next_seq += len(fresh)
        return len(fresh)
@dataclass
class StreamState:
    """Both directions of one tracked flow plus bookkeeping timestamps."""

    # Reassembly state for each direction of the flow.
    client_to_server: StreamDirectionState = field(default_factory=StreamDirectionState)
    server_to_client: StreamDirectionState = field(default_factory=StreamDirectionState)
    # Wall-clock time (time.time()) at which this state object was created.
    created_at: float = field(default_factory=time.time)
    # Wall-clock time of the most recent activity; initialised at creation.
    last_seen: float = field(default_factory=time.time)
class TCPStreamReassembler:
    """Track TCP flows and rebuild each direction's byte stream."""

    def __init__(self, logger: Optional[logging.Logger] = None, max_streams: int = 10000, stream_ttl: int = 1800) -> None:
        """Create an empty stream table.

        Args:
            logger: Optional logger; a module logger is used when omitted.
            max_streams: Upper bound on concurrently tracked flows.
            stream_ttl: Seconds of inactivity after which a flow is dropped.
        """
        self.logger = logger or logging.getLogger(__name__)
        self.max_streams = max_streams
        self.stream_ttl = stream_ttl
        self.streams: Dict[FlowKey, StreamState] = {}

    def _evict_old(self) -> None:
        """Drop idle flows, then the least recently seen ones above the cap."""
        now = time.time()
        expired = [flow for flow, state in self.streams.items() if now - state.last_seen > self.stream_ttl]
        for flow in expired:
            self.streams.pop(flow, None)

        if len(self.streams) > self.max_streams:
            excess = len(self.streams) - self.max_streams
            oldest_first = sorted(self.streams.items(), key=lambda item: item[1].last_seen)
            for flow, _ in oldest_first[:excess]:
                self.streams.pop(flow, None)

    def add_tcp_segment(self, src_ip: str, src_port: int, dst_ip: str, dst_port: int, seq: int, payload: bytes) -> Dict[str, Any]:
        """Feed one TCP segment and return a snapshot of its flow.

        The "client" direction is whichever endpoint ``canonical_flow_key``
        stores first in the key, i.e. it is derived from endpoint ordering.

        Returns:
            Dict with ``flow_key``, ``is_client_to_server``, ``bytes_added``
            and copies of both reassembled streams.
        """
        key, forward_is_client = canonical_flow_key(src_ip, src_port, dst_ip, dst_port)
        stream = self.streams.setdefault(key, StreamState())
        stream.last_seen = time.time()

        if forward_is_client:
            direction = stream.client_to_server
        else:
            direction = stream.server_to_client
        newly_contiguous = direction.add(seq, payload)

        self._evict_old()

        snapshot: Dict[str, Any] = {
            "flow_key": key,
            "is_client_to_server": forward_is_client,
            "bytes_added": newly_contiguous,
            "client_stream": bytes(stream.client_to_server.contiguous),
            "server_stream": bytes(stream.server_to_client.contiguous),
        }
        return snapshot
class ProtocolDecoder:
    """Best-effort application-protocol decoders (HTTP/1, DNS, TLS, SMTP, SMB)."""

    _HTTP_PREFIXES = (b"GET ", b"POST ", b"PUT ", b"DELETE ", b"HEAD ", b"OPTIONS ", b"HTTP/1.")
    _SMTP_PREFIXES = (b"HELO ", b"EHLO ", b"MAIL ", b"RCPT ", b"DATA\r")
    _SMB_MAGICS = (b"\xfeSMB", b"\xffSMB")

    def __init__(self, logger: Optional[logging.Logger] = None) -> None:
        """Store the logger used by the decoders."""
        self.logger = logger or logging.getLogger(__name__)

    def decode_http(self, payload: bytes, direction: str) -> Dict[str, Any]:
        """Split an HTTP/1 message into start line, lower-cased headers and body bytes."""
        # iso-8859-1 maps every byte to one code point, so the body round-trips losslessly.
        text = payload.decode("iso-8859-1", errors="replace")
        head, _, body = text.partition("\r\n\r\n")
        lines = head.split("\r\n") if head else []
        headers: Dict[str, str] = {}
        for line in lines[1:]:
            if ":" in line:
                k, v = line.split(":", 1)
                headers[k.strip().lower()] = v.strip()
        return {
            "protocol": "http1",
            "direction": direction,
            "start_line": lines[0] if lines else "",
            "headers": headers,
            "body_bytes": body.encode("iso-8859-1", errors="ignore"),
            "host": headers.get("host", ""),
            "content_type": headers.get("content-type", ""),
            "transfer_encoding": headers.get("transfer-encoding", "").lower(),
            "content_encoding": headers.get("content-encoding", "").lower(),
        }

    def decode_dns(self, payload: bytes) -> Dict[str, Any]:
        """Decode the DNS header and the first question name."""
        if len(payload) < 12:
            return {"protocol": "dns", "error": "short_payload"}
        tid, flags, qdcount, ancount, nscount, arcount = struct.unpack("!HHHHHH", payload[:12])
        qr = (flags >> 15) & 0x1
        opcode = (flags >> 11) & 0xF
        rcode = flags & 0xF
        query_name = ""
        if qdcount > 0:
            try:
                query_name, _ = DNSParser(logger=self.logger).parse_dns_name(payload, 12)
            except Exception:
                query_name = ""
        return {
            "protocol": "dns",
            "transaction_id": tid,
            "is_response": bool(qr),
            "opcode": opcode,
            "rcode": rcode,
            "qdcount": qdcount,
            "ancount": ancount,
            "nscount": nscount,
            "arcount": arcount,
            "query": query_name,
        }

    @staticmethod
    def _u16_list(data: bytes) -> List[int]:
        """Interpret ``data`` as consecutive big-endian 16-bit values (a trailing odd byte is ignored)."""
        return [int.from_bytes(data[i:i + 2], "big") for i in range(0, len(data) - 1, 2)]

    @staticmethod
    def _is_grease(value: int) -> bool:
        """Return True for TLS GREASE placeholders (0x0A0A, 0x1A1A, ..., 0xFAFA)."""
        return (value & 0x0F0F) == 0x0A0A and (value >> 8) == (value & 0xFF)

    def _parse_tls_client_hello(self, payload: bytes) -> Dict[str, Any]:
        """Extract SNI and the JA3 fingerprint from a TLS ClientHello record.

        The JA3 string is ``version,ciphers,extensions,curves,point_formats``
        where ``version`` is the ClientHello's own version field and GREASE
        values are left out of the cipher, extension and curve lists, as the
        JA3 definition requires.

        Returns:
            A dict with ``protocol``, ``record_type``, ``sni``, ``ja3`` and
            ``ja3_hash``; or an empty dict when ``payload`` is not a parseable
            ClientHello.
        """
        if len(payload) < 5 or payload[0] != 0x16:  # 0x16 = handshake record
            return {}
        rec_len = int.from_bytes(payload[3:5], "big")
        rec = payload[5:5 + rec_len]
        if len(rec) < 4 or rec[0] != 0x01:  # 0x01 = ClientHello
            return {}
        hs_len = int.from_bytes(rec[1:4], "big")
        body = rec[4:4 + hs_len]
        # 2-byte version + 32-byte random, then at least the session-id length byte.
        if len(body) <= 34:
            return {}
        client_version = int.from_bytes(body[0:2], "big")
        idx = 34
        sess_len = body[idx]
        idx += 1 + sess_len
        if idx + 2 > len(body):
            return {}
        ciphers_len = int.from_bytes(body[idx:idx + 2], "big")
        idx += 2
        ciphers_raw = body[idx:idx + ciphers_len]
        idx += ciphers_len
        if idx >= len(body):
            return {}
        comp_len = body[idx]
        idx += 1 + comp_len

        exts: List[int] = []
        curves: List[int] = []
        ec_pf: List[int] = []
        sni = ""
        if idx + 2 <= len(body):
            ext_total = int.from_bytes(body[idx:idx + 2], "big")
            idx += 2
            end = min(len(body), idx + ext_total)
            while idx + 4 <= end:
                etype = int.from_bytes(body[idx:idx + 2], "big")
                elen = int.from_bytes(body[idx + 2:idx + 4], "big")
                idx += 4
                eval_data = body[idx:idx + elen]
                idx += elen
                exts.append(etype)
                if etype == 0 and len(eval_data) >= 5:  # server_name
                    list_len = int.from_bytes(eval_data[0:2], "big")
                    if 2 + list_len <= len(eval_data) and eval_data[2] == 0:
                        name_len = int.from_bytes(eval_data[3:5], "big")
                        if 5 + name_len <= len(eval_data):
                            sni = eval_data[5:5 + name_len].decode("utf-8", errors="replace")
                elif etype == 10 and len(eval_data) >= 2:  # supported_groups
                    glen = int.from_bytes(eval_data[0:2], "big")
                    curves = self._u16_list(eval_data[2:2 + glen])
                elif etype == 11 and len(eval_data) >= 1:  # ec_point_formats
                    flen = eval_data[0]
                    ec_pf = list(eval_data[1:1 + flen])

        ciphers = self._u16_list(ciphers_raw)
        is_grease = self._is_grease
        ja3_string = ",".join([
            str(client_version),
            "-".join(str(c) for c in ciphers if not is_grease(c)),
            "-".join(str(e) for e in exts if not is_grease(e)),
            "-".join(str(c) for c in curves if not is_grease(c)),
            "-".join(str(p) for p in ec_pf),
        ])
        # MD5 is mandated by the JA3 definition; it is a fingerprint, not a security control.
        ja3_hash = hashlib.md5(ja3_string.encode("utf-8", errors="ignore")).hexdigest()
        return {
            "protocol": "tls",
            "record_type": "handshake",
            "sni": sni,
            "ja3": ja3_string,
            "ja3_hash": ja3_hash,
        }

    def decode_smtp(self, payload: bytes) -> Dict[str, Any]:
        """Return the first word of up to 25 non-empty lines plus the total line count."""
        text = payload.decode("utf-8", errors="replace")
        lines = [line.strip() for line in text.splitlines() if line.strip()]
        commands = [line.split(" ", 1)[0].upper() for line in lines[:25]]
        return {"protocol": "smtp", "commands": commands, "line_count": len(lines)}

    def decode_smb(self, payload: bytes) -> Dict[str, Any]:
        """Recognise the SMB magic and report the byte that follows it as ``command``."""
        if len(payload) >= 4 and payload[:4] in self._SMB_MAGICS:
            return {"protocol": "smb", "signature": payload[:4].hex(), "command": payload[4] if len(payload) > 4 else None}
        return {"protocol": "smb", "error": "not_smb"}

    def decode_by_ports_or_signature(self, src_port: int, dst_port: int, payload: bytes, direction: str) -> Dict[str, Any]:
        """Pick a decoder from well-known ports or payload prefixes.

        Checks run in a fixed order (DNS, HTTP, TLS, SMTP, SMB); the first
        match wins. Empty or unrecognised payloads yield
        ``{"protocol": "unknown"}``.
        """
        ports = {src_port, dst_port}
        if not payload:
            return {"protocol": "unknown"}
        if 53 in ports:
            return self.decode_dns(payload)
        if payload.startswith(self._HTTP_PREFIXES) or 80 in ports or 8080 in ports:
            return self.decode_http(payload, direction)
        if payload[:1] == b"\x16" or 443 in ports:
            tls_data = self._parse_tls_client_hello(payload)
            return tls_data if tls_data else {"protocol": "tls", "record_type": "unknown"}
        if 25 in ports or payload[:5].upper() in self._SMTP_PREFIXES:
            return self.decode_smtp(payload)
        if 445 in ports or payload.startswith(self._SMB_MAGICS):
            return self.decode_smb(payload)
        return {"protocol": "unknown"}
class ContentDecoder:
    """Undo HTTP transfer/content encodings and sniff a few container formats."""

    def __init__(self, logger: Optional[logging.Logger] = None) -> None:
        """Store the logger used for decode diagnostics."""
        self.logger = logger or logging.getLogger(__name__)

    def decode_chunked(self, body: bytes) -> bytes:
        """Reassemble a ``Transfer-Encoding: chunked`` body.

        Decoding stops at the terminating zero-size chunk or at the first
        malformed size line (including a negative size); whatever was decoded
        up to that point is returned.
        """
        out = bytearray()
        idx = 0
        while idx < len(body):
            end = body.find(b"\r\n", idx)
            if end == -1:
                break
            # Chunk extensions (";name=value") are not part of the size.
            size_line = body[idx:end].split(b";", 1)[0].strip()
            try:
                size = int(size_line, 16)
            except ValueError:
                break
            if size <= 0:
                # 0 is the terminator; a negative size is invalid and would
                # otherwise move the cursor backwards.
                break
            idx = end + 2
            out.extend(body[idx:idx + size])
            idx += size + 2  # skip the chunk data and its trailing CRLF
        return bytes(out)

    @staticmethod
    def _inflate(data: bytes) -> bytes:
        """Decompress a ``deflate`` body, accepting zlib-wrapped or raw DEFLATE data."""
        try:
            return zlib.decompress(data)
        except zlib.error:
            # Some servers send a raw DEFLATE stream without the zlib header.
            return zlib.decompress(data, -zlib.MAX_WBITS)

    def decode_http_content(self, decoded_http: Dict[str, Any]) -> bytes:
        """Return the HTTP body with chunking and gzip/deflate encoding removed.

        Args:
            decoded_http: Mapping with ``body_bytes`` and optionally
                ``transfer_encoding`` / ``content_encoding``.

        Returns:
            The decoded body. If decompression fails the (de-chunked) body is
            returned unchanged and the failure is logged at debug level.
        """
        # NOTE(review): decompression output is unbounded; traffic is untrusted,
        # so a size cap may be worth adding.
        body = decoded_http.get("body_bytes", b"")
        if decoded_http.get("transfer_encoding") == "chunked":
            body = self.decode_chunked(body)
        content_encoding = decoded_http.get("content_encoding", "")
        try:
            if "gzip" in content_encoding:
                body = gzip.decompress(body)
            elif "deflate" in content_encoding:
                body = self._inflate(body)
        except Exception as exc:
            self.logger.debug("Content decoding failed: %s", exc)
        return body

    def try_base64(self, payload: bytes) -> bytes:
        """Decode ``payload`` as base64, or return ``b""`` if it does not look like base64.

        Whitespace is ignored; fewer than 16 remaining characters, or any
        character outside the base64 alphabet, yields ``b""``.
        """
        cleaned = b"".join(payload.split())
        if len(cleaned) < 16:
            return b""
        if not re.fullmatch(rb"[A-Za-z0-9+/=]+", cleaned):
            return b""
        try:
            return binascii.a2b_base64(cleaned)
        except Exception:
            return b""

    def detect_container(self, payload: bytes) -> str:
        """Classify ``payload`` by magic bytes as ``zip``, ``pdf``, ``pe``, ``ole`` or ``unknown``."""
        if payload.startswith(b"PK\x03\x04"):
            return "zip"
        if payload.startswith(b"%PDF-"):
            return "pdf"
        if payload.startswith(b"MZ"):
            return "pe"
        if payload.startswith(bytes.fromhex("D0CF11E0A1B11AE1")):
            return "ole"
        return "unknown"
class FileObjectCarver:
    """Write recognised file objects found in payloads to disk."""

    # Payloads shorter than this are never carved.
    MIN_CARVE_SIZE = 64

    def __init__(self, logger: Optional[logging.Logger] = None, output_dir: str = "carved_objects") -> None:
        """Store settings and make sure ``output_dir`` exists."""
        self.logger = logger or logging.getLogger(__name__)
        self.output_dir = output_dir
        os.makedirs(self.output_dir, exist_ok=True)

    def carve(self, flow_key: FlowKey, payload: bytes, source: str) -> Optional[Dict[str, Any]]:
        """Save ``payload`` to a file if it starts with a known container magic.

        Args:
            flow_key: Flow the payload belongs to; its endpoints go into the file name.
            payload: Candidate file bytes.
            source: Short tag describing where the payload came from.

        Returns:
            ``{"path", "sha256", "container", "size"}`` for a carved object, or
            ``None`` when the payload is too short or not a recognised container.
        """
        from datetime import timezone  # local import: the file only imports datetime.datetime

        # Cheap length check first so short payloads never reach container detection.
        if len(payload) < self.MIN_CARVE_SIZE:
            return None
        container = ContentDecoder(logger=self.logger).detect_container(payload)
        if container == "unknown":
            return None

        # datetime.utcnow() is deprecated; an aware UTC "now" formats identically here.
        timestamp = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%S%f")
        filename = f"{flow_key.src_ip}_{flow_key.src_port}_{flow_key.dst_ip}_{flow_key.dst_port}_{source}_{timestamp}.{container}"
        # Keep only characters that are safe in a file name (IPv6 colons, etc. become "_").
        safe_name = re.sub(r"[^A-Za-z0-9_.-]", "_", filename)
        path = os.path.join(self.output_dir, safe_name)
        with open(path, "wb") as f:
            f.write(payload)
        sha256 = hashlib.sha256(payload).hexdigest()
        return {"path": path, "sha256": sha256, "container": container, "size": len(payload)}
class SignatureScanner:
    """Scan bytes or files with YARA rules and, when available, a local AV tool."""

    def __init__(self, logger: Optional[logging.Logger] = None, yara_rules_path: str = "") -> None:
        """Compile the YARA rules at ``yara_rules_path`` when yara and the file are both available."""
        self.logger = logger or logging.getLogger(__name__)
        self.yara_rules_path = yara_rules_path
        self.yara_rules = None

        rules_usable = yara is not None and yara_rules_path and os.path.exists(yara_rules_path)
        if rules_usable:
            try:
                self.yara_rules = yara.compile(filepath=yara_rules_path)
                self.logger.info("YARA rules loaded: %s", yara_rules_path)
            except Exception as exc:
                self.logger.warning("Failed to load YARA rules: %s", exc)

    def scan_bytes(self, payload: bytes) -> List[str]:
        """Return the names of YARA rules matching ``payload`` (empty without rules or on error)."""
        rules = self.yara_rules
        if rules is None:
            return []
        try:
            return [hit.rule for hit in rules.match(data=payload)]
        except Exception as exc:
            self.logger.debug("YARA byte scan failed: %s", exc)
            return []

    def scan_file(self, path: str) -> Dict[str, Any]:
        """Scan the file at ``path``.

        Returns:
            ``{"yara": [rule names], "av": text}`` where ``av`` is
            ``"not_run"`` when no scanner executable is found, the truncated
            scanner output otherwise, or ``"scan_failed: ..."`` on error.
        """
        report = {"yara": [], "av": "not_run"}

        if self.yara_rules is not None:
            try:
                report["yara"] = [hit.rule for hit in self.yara_rules.match(path)]
            except Exception as exc:
                self.logger.debug("YARA file scan failed for %s: %s", path, exc)

        # Prefer ClamAV; otherwise try the Windows Defender command-line scanner.
        if shutil.which("clamscan"):
            command = ["clamscan", "--no-summary", path]
        elif shutil.which("MpCmdRun.exe"):
            command = ["MpCmdRun.exe", "-Scan", "-ScanType", "3", "-File", path]
        else:
            command = None

        if command:
            try:
                completed = subprocess.run(command, capture_output=True, text=True, timeout=30)
                combined = (completed.stdout or "") + "\n" + (completed.stderr or "")
                report["av"] = combined.strip()[:1000]
            except Exception as exc:
                report["av"] = f"scan_failed: {exc}"
        return report
class HeuristicDetector:
    """Cheap statistical heuristics over payloads, domains and per-host activity."""

    _CONSONANTS = frozenset("bcdfghjklmnpqrstvwxyz")
    _VOWELS = frozenset("aeiou")
    # Characters considered normal inside a host label.
    _LABEL_CHARS = frozenset(string.ascii_lowercase + string.digits + "-")

    def __init__(self, logger: Optional[logging.Logger] = None) -> None:
        """Create empty, bounded per-source histories."""
        self.logger = logger or logging.getLogger(__name__)
        # Each source IP keeps at most the 128 most recent (timestamp, value) entries.
        self.conn_history: Dict[str, deque] = defaultdict(lambda: deque(maxlen=128))
        self.dns_history: Dict[str, deque] = defaultdict(lambda: deque(maxlen=128))

    @staticmethod
    def entropy(payload: bytes) -> float:
        """Return the Shannon entropy of ``payload`` in bits per byte (0.0 for empty input).

        Computed with the standard library so the result does not silently
        degrade to 0.0 when numpy is unavailable.
        """
        import math  # local import: not part of the file's top-level imports

        if not payload:
            return 0.0
        total = len(payload)
        return -sum((c / total) * math.log2(c / total) for c in Counter(payload).values())

    def score_dga_like(self, domain: str) -> float:
        """Score in [0, 1] how much the first label of ``domain`` looks machine-generated.

        Labels shorter than 12 characters always score 0.0.
        """
        if not domain:
            return 0.0
        label = domain.split(".")[0].lower()
        if len(label) < 12:
            return 0.0
        size = len(label)
        consonants = sum(ch in self._CONSONANTS for ch in label)
        vowels = sum(ch in self._VOWELS for ch in label)
        digit_ratio = sum(ch.isdigit() for ch in label) / size
        weird_ratio = sum(ch not in self._LABEL_CHARS for ch in label) / size
        imbalance = abs(consonants - vowels) / size
        return min(1.0, digit_ratio * 0.6 + weird_ratio * 0.8 + imbalance)

    def track_connect(self, src_ip: str, dst_ip: str) -> float:
        """Record a connection and return a fan-out score in [0, 1].

        The score is the number of distinct destinations contacted by
        ``src_ip`` in the last 60 seconds divided by 40, capped at 1.0.
        """
        now = time.time()
        history = self.conn_history[src_ip]
        history.append((now, dst_ip))
        recent_targets = {target for ts, target in history if now - ts <= 60}
        return min(1.0, len(recent_targets) / 40.0)

    def track_dns(self, src_ip: str, domain: str) -> float:
        """Record a DNS query and return a query-rate score in [0, 1].

        The score is the number of queries from ``src_ip`` in the last
        60 seconds divided by 80, capped at 1.0.
        """
        now = time.time()
        history = self.dns_history[src_ip]
        history.append((now, domain))
        recent_count = sum(1 for ts, _ in history if now - ts <= 60)
        return min(1.0, recent_count / 80.0)
class SessionCorrelator:
    """Keep a bounded per-host event history and look for a kill-chain pattern."""

    # Stages that must occur in this order for a kill chain to be reported.
    KILL_CHAIN_STAGES = ("dns_query", "connect", "download", "execute")
    # Only this many of the most recent events are inspected.
    KILL_CHAIN_WINDOW = 20

    def __init__(self, logger: Optional[logging.Logger] = None) -> None:
        """Create the per-host event store (256 most recent events per host)."""
        self.logger = logger or logging.getLogger(__name__)
        self.events: Dict[str, deque] = defaultdict(lambda: deque(maxlen=256))

    def add_event(self, host: str, kind: str, details: Dict[str, Any]) -> None:
        """Append a timestamped event of type ``kind`` for ``host``."""
        self.events[host].append({"ts": time.time(), "kind": kind, "details": details})

    def detect_kill_chain(self, host: str) -> bool:
        """Return True if the host's recent events contain the kill-chain stages in order.

        The stages must appear as an ordered subsequence of the most recent
        events; unrelated events, including earlier occurrences of a later
        stage such as a "connect" before the first "dns_query", may be
        interleaved.
        """
        # .get() avoids creating an empty history for an unknown host.
        history = self.events.get(host, ())
        if len(history) < 3:
            return False
        kinds = iter([event["kind"] for event in list(history)[-self.KILL_CHAIN_WINDOW:]])
        # "stage in kinds" consumes the iterator up to the match, so each
        # stage must be found after the previous one.
        return all(stage in kinds for stage in self.KILL_CHAIN_STAGES)
class MLClassifier:
    """Payload classifier backed by a joblib-serialised model, with a heuristic fallback."""

    def __init__(self, model_path: str = "model.pkl", logger: Optional[logging.Logger] = None) -> None:
        """Try to load the model at ``model_path``; on failure ``self.model`` is None."""
        self.logger = logger or logging.getLogger(__name__)
        try:
            import joblib
            # SECURITY: joblib.load unpickles the file and can execute arbitrary
            # code. Only load model files from a trusted source.
            self.model = joblib.load(model_path)
            self.logger.info("ML model loaded successfully from %s", model_path)
        except Exception as e:
            self.logger.error("Failed to load ML model from %s: %s", model_path, e)
            self.model = None

    @staticmethod
    def _compute_features(payload: bytes) -> "np.ndarray":
        """Return a 1x2 array ``[[length, shannon_entropy_bits]]`` for ``payload``."""
        length = len(payload)
        total = length if length > 0 else 1
        counts = Counter(payload)
        entropy = -sum((count / total) * np.log2(count / total) for count in counts.values() if count > 0)
        return np.array([[length, entropy]])

    def extract_features(self, payload: bytes) -> "np.ndarray":
        """Return the feature row for ``payload`` (see ``_compute_features``)."""
        return self._compute_features(payload)

    def classify(self, payload: bytes) -> bool:
        """Return True when ``payload`` is considered malicious.

        Uses the loaded model when available; otherwise flags payloads longer
        than 1000 bytes whose entropy exceeds 7.0 bits per byte.
        """
        features = self.extract_features(payload)
        if self.model is not None:
            prediction = self.model.predict(features)
            self.logger.debug("ML model prediction: %s", prediction)
            return bool(prediction[0])
        if features[0, 0] > 1000 and features[0, 1] > 7.0:
            self.logger.debug("Fallback heuristic: payload marked as malicious.")
            return True
        return False

    @staticmethod
    def train_model(dataset_path: str, model_output_path: str, logger: Optional[logging.Logger] = None) -> None:
        """Train a random forest on a CSV dataset and save it with joblib.

        The CSV is read for a hex-encoded ``payload`` column and a ``label``
        column (missing labels default to 0). Rows whose payload cannot be
        converted from hex are logged and skipped.

        Args:
            dataset_path: Path to the CSV dataset.
            model_output_path: Where the trained model is written.
            logger: Optional logger; a module logger is used when omitted.
        """
        import pandas as pd
        from sklearn.ensemble import RandomForestClassifier
        import joblib

        logger = logger or logging.getLogger(__name__)
        logger.info("Loading dataset from %s", dataset_path)
        try:
            df = pd.read_csv(dataset_path)
        except Exception as e:
            logger.error("Failed to load dataset: %s", e)
            return

        features = []
        labels = []
        for index, row in df.iterrows():
            payload_str = row.get('payload', '')
            try:
                payload_bytes = bytes.fromhex(payload_str)
            except Exception as e:
                logger.error("Error converting payload to bytes for row %d: %s", index, e)
                continue
            # Features are computed statically: building an MLClassifier here
            # would try to load the default model file as a side effect.
            features.append(MLClassifier._compute_features(payload_bytes)[0])
            labels.append(row.get('label', 0))

        if not features:
            logger.error("No valid training samples found.")
            return
        logger.info("Training model on %d samples", len(features))
        clf = RandomForestClassifier(n_estimators=100, random_state=42)
        clf.fit(features, labels)
        joblib.dump(clf, model_output_path)
        logger.info("Model trained and saved to %s", model_output_path)
def get_local_process_info(port: int) -> str:
    """Return the PID (as a string) of the first inet socket bound to local ``port``.

    Returns:
        The PID string; ``"N/A"`` when no matching socket is found, its PID is
        unknown, or the connection table cannot be read; or
        ``"N/A (psutil not installed)"`` when psutil is missing.
    """
    try:
        import psutil
    except ImportError:
        return "N/A (psutil not installed)"

    try:
        connections = psutil.net_connections(kind="inet")
    except (psutil.Error, OSError):
        # Listing system-wide sockets can require elevated privileges;
        # treat that as "unknown" rather than crashing the caller.
        return "N/A"

    for conn in connections:
        if conn.laddr and conn.laddr.port == port:
            return str(conn.pid) if conn.pid is not None else "N/A"
    return "N/A"
class XDPCollector:
    """Attach a BCC XDP program to an interface and publish per-packet metadata.

    Each event overwrites the module-level ``last_xdp_metadata`` dict.
    """

    def __init__(self, interface: str, logger: Optional[logging.Logger] = None) -> None:
        """Compile the BPF program, attach it to ``interface`` and open the perf buffer.

        Raises:
            ImportError: If the bcc ``BPF`` class is not available.
        """
        if BPF is None:
            raise ImportError("BCC BPF module is not available. Please install bcc.")
        self.logger = logger or logging.getLogger(__name__)
        self.interface = interface
        self.bpf = BPF(text=self._bpf_program())
        xdp_fn = self.bpf.load_func("xdp_prog", BPF.XDP)
        self.bpf.attach_xdp(self.interface, xdp_fn, 0)
        self.logger.info("XDP program attached on interface %s", self.interface)
        events_table = self.bpf["events"]
        events_table.open_perf_buffer(self._handle_event)

    def _bpf_program(self) -> str:
        """Return the C source of the XDP program.

        For IPv4 frames the program submits packet length, source/destination
        address and IP protocol to the ``events`` perf buffer; every packet is
        passed on (``XDP_PASS``).
        """
        return """
#include <uapi/linux/bpf.h>
#include <linux/if_ether.h>
#include <linux/ip.h>
struct data_t {
u32 pkt_len;
u32 src_ip;
u32 dst_ip;
u8 protocol;
};
BPF_PERF_OUTPUT(events);
int xdp_prog(struct xdp_md *ctx) {
struct data_t data = {};
void *data_end = (void *)(long)ctx->data_end;
void *data_ptr = (void *)(long)ctx->data;
struct ethhdr *eth = data_ptr;
if (data_ptr + sizeof(*eth) > data_end)
return XDP_PASS;
if (eth->h_proto == __constant_htons(ETH_P_IP)) {
struct iphdr *ip = data_ptr + sizeof(*eth);
if ((void*)ip + sizeof(*ip) > data_end)
return XDP_PASS;
data.pkt_len = data_end - data_ptr;
data.src_ip = ip->saddr;
data.dst_ip = ip->daddr;
data.protocol = ip->protocol;
events.perf_submit(ctx, &data, sizeof(data));
}
return XDP_PASS;
}
"""

    def _handle_event(self, cpu, data, size):
        """Perf-buffer callback: decode one event and store it in ``last_xdp_metadata``."""
        global last_xdp_metadata
        event = self.bpf["events"].event(data)
        # Addresses are stored exactly as delivered by the BPF program (raw u32 values).
        last_xdp_metadata = {
            "pkt_len": event.pkt_len,
            "src_ip": event.src_ip,
            "dst_ip": event.dst_ip,
            "protocol": event.protocol
        }
        self.logger.debug("XDP event: %s", last_xdp_metadata)

    def poll(self):
        """Poll the perf buffer once with a 100 ms timeout."""
        self.bpf.perf_buffer_poll(timeout=100)

    def detach(self):
        """Remove the XDP program from the interface."""
        self.bpf.remove_xdp(self.interface, 0)
def get_xdp_metadata_details() -> str:
    """Describe the most recent XDP event, or say that none is available."""
    metadata = last_xdp_metadata
    if not metadata:
        return "No XDP metadata available."
    return (f"XDP Metadata: Packet Length: {metadata.get('pkt_len')}, "
            f"Src IP: {metadata.get('src_ip')}, "
            f"Dst IP: {metadata.get('dst_ip')}, "
            f"Protocol: {metadata.get('protocol')}")
class PacketSniffer:
    """Scapy-based sniffer that logs, decodes and scores captured packets.

    Every captured packet goes through :meth:`packet_handler`, which filters
    it by source address, logs the protocol layers it finds and hands
    payloads to the decoders, signature scanner, heuristics and (in sentry
    mode) the ML classifier created in :meth:`__init__`.
    """

    # Well-known port -> (short name, description); used by
    # identify_application() for best-effort protocol naming.
    COMMON_PORTS = {
        20: ("FTP Data", "File Transfer Protocol - Data channel"),
        21: ("FTP Control", "File Transfer Protocol - Control channel"),
        22: ("SSH", "Secure Shell"),
        23: ("Telnet", "Telnet protocol"),
        25: ("SMTP", "Simple Mail Transfer Protocol"),
        53: ("DNS", "Domain Name System"),
        67: ("DHCP", "DHCP Server"),
        68: ("DHCP", "DHCP Client"),
        80: ("HTTP", "Hypertext Transfer Protocol"),
        110: ("POP3", "Post Office Protocol"),
        119: ("NNTP", "Network News Transfer Protocol"),
        123: ("NTP", "Network Time Protocol"),
        143: ("IMAP", "Internet Message Access Protocol"),
        161: ("SNMP", "Simple Network Management Protocol"),
        443: ("HTTPS", "HTTP Secure"),
        3306: ("MySQL", "MySQL database service"),
        5432: ("PostgreSQL", "PostgreSQL database service"),
        3389: ("RDP", "Remote Desktop Protocol")
    }

    # File that block_ip() appends one "Blocking IP <addr>" line to per call.
    BLOCK_LOG_PATH = "BLOCKER.GARY"

    def __init__(self, args: argparse.Namespace) -> None:
        """Create the analysis components and optionally start XDP polling.

        Args:
            args: Parsed command-line options. The attributes read by this
                class include ``sentry``, ``model_path``, ``xdp``,
                ``interface``, ``local_only``, ``red_alert``, ``whitelist``,
                ``expensive`` and, when present, ``carve_dir``,
                ``yara_rules`` and ``no_bgcheck``.
        """
        self.args = args
        self.logger = logging.getLogger("PacketSniffer")
        self.flagged_logger = logging.getLogger("FlaggedIPLogger")
        self.dns_logger = logging.getLogger("DNSQueryLogger")
        self.ip_lookup = IPLookup(ttl=3600, logger=self.logger)
        self.dns_parser = DNSParser(logger=self.logger)
        self.payload_parser = PayloadParser(logger=self.logger)
        self.protocol_decoder = ProtocolDecoder(logger=self.logger)
        self.content_decoder = ContentDecoder(logger=self.logger)
        self.file_carver = FileObjectCarver(logger=self.logger, output_dir=getattr(self.args, "carve_dir", "carved_objects"))
        self.signature_scanner = SignatureScanner(logger=self.logger, yara_rules_path=getattr(self.args, "yara_rules", ""))
        self.heuristics = HeuristicDetector(logger=self.logger)
        self.correlator = SessionCorrelator(logger=self.logger)
        self.reassembler = TCPStreamReassembler(logger=self.logger)
        self.async_runner = AsyncRunner()
        # The classifier is only created (and only used) in sentry mode.
        if self.args.sentry:
            self.ml_classifier = MLClassifier(model_path=self.args.model_path, logger=self.logger)
        self.xdp_collector = None
        self.xdp_thread = None
        # Set by _shutdown_xdp() so the polling thread can exit cleanly.
        self._xdp_stop = threading.Event()
        if self.args.xdp and BPF is not None:
            try:
                self.xdp_collector = XDPCollector(self.args.interface, logger=self.logger)
                self.xdp_thread = threading.Thread(target=self._poll_xdp, daemon=True)
                self.xdp_thread.start()
            except Exception as e:
                self.logger.exception("Failed to initialize XDPCollector: %s", e)

    def _risk_from_signals(self, signature_hits: List[str], heuristic_score: float, ml_flag: bool) -> float:
        """Fold the detection signals into a single risk score capped at 1.0.

        Signature hits add 0.2 each (at most 0.6), the heuristic score adds
        up to 0.3 and a positive ML verdict adds 0.3.
        """
        score = 0.0
        score += min(0.6, 0.2 * len(signature_hits))
        score += min(0.3, heuristic_score * 0.3)
        if ml_flag:
            score += 0.3
        return min(1.0, score)

    def _emit_session_events(self, src_ip: str, dst_ip: str, decoded: Dict[str, Any], carved: Optional[Dict[str, Any]]) -> None:
        """Record connect/dns_query/download/execute events for ``src_ip``.

        A ``connect`` event is always added; ``dns_query`` is added for
        decoded DNS queries, ``download`` when an object was carved, and
        ``execute`` when the carved container is ``"pe"``.
        """
        protocol = decoded.get("protocol", "unknown")
        self.correlator.add_event(src_ip, "connect", {"dst_ip": dst_ip, "protocol": protocol})
        if protocol == "dns" and decoded.get("query"):
            self.correlator.add_event(src_ip, "dns_query", {"query": decoded.get("query")})
        if carved is not None:
            self.correlator.add_event(src_ip, "download", {"sha256": carved.get("sha256"), "path": carved.get("path")})
            if carved.get("container") in ("pe",):
                self.correlator.add_event(src_ip, "execute", {"container": carved.get("container")})

    def _analyze_reassembled_tcp(self, src_ip: str, src_port: int, dst_ip: str, dst_port: int, tcp_seq: int, payload: bytes) -> None:
        """Feed a TCP segment to the reassembler and analyse the updated stream.

        The last 64 KiB of the stream for the segment's direction is decoded,
        scanned for signatures, carved for files and scored. When the risk
        score reaches 0.7 or a kill chain is detected, an alert is written to
        the flagged-IP logger and, in sentry mode, the source is passed to
        :meth:`block_ip`.
        """
        if not payload:
            return
        reassembled = self.reassembler.add_tcp_segment(src_ip, src_port, dst_ip, dst_port, tcp_seq, payload)
        if reassembled.get("bytes_added", 0) == 0:
            # Nothing new was added to the stream, so there is nothing to re-analyse.
            return
        flow_key: FlowKey = reassembled["flow_key"]
        direction = "c2s" if reassembled.get("is_client_to_server") else "s2c"
        stream_payload = reassembled["client_stream"] if direction == "c2s" else reassembled["server_stream"]
        tail = stream_payload[-65536:]
        decoded = self.protocol_decoder.decode_by_ports_or_signature(src_port, dst_port, tail, direction)
        protocol = decoded.get("protocol", "unknown")
        signature_hits: List[str] = []

        # Prefer decoded content (HTTP body or base64 payload) over raw bytes.
        candidate_payload = tail
        if protocol == "http1":
            candidate_payload = self.content_decoder.decode_http_content(decoded)
        else:
            b64 = self.content_decoder.try_base64(tail)
            if b64:
                candidate_payload = b64

        signature_hits.extend(self.signature_scanner.scan_bytes(candidate_payload))
        carved = self.file_carver.carve(flow_key, candidate_payload, source=direction)
        file_scan: Dict[str, Any] = {"yara": [], "av": "not_run"}
        if carved is not None:
            file_scan = self.signature_scanner.scan_file(carved["path"])
            signature_hits.extend(file_scan.get("yara", []))

        # Heuristics: the strongest of entropy, DGA-likeness and beaconing wins.
        entropy_score = 1.0 if self.heuristics.entropy(candidate_payload[:4096]) >= 7.3 else 0.0
        dga_score = 0.0
        if protocol == "dns":
            dga_score = self.heuristics.score_dga_like(decoded.get("query", ""))
        beacon_score = self.heuristics.track_connect(src_ip, dst_ip)
        if protocol == "dns":
            _ = self.heuristics.track_dns(src_ip, decoded.get("query", ""))
        heuristic_score = min(1.0, max(entropy_score, dga_score, beacon_score))

        ml_flag = False
        if self.args.sentry:
            ml_flag = self.ml_classifier.classify(candidate_payload)
        risk = self._risk_from_signals(signature_hits, heuristic_score, ml_flag)
        self._emit_session_events(src_ip, dst_ip, decoded, carved)
        kill_chain = self.correlator.detect_kill_chain(src_ip)
        tls_meta = {}
        if protocol == "tls":
            tls_meta = {
                "sni": decoded.get("sni", ""),
                "ja3": decoded.get("ja3", ""),
                "ja3_hash": decoded.get("ja3_hash", ""),
            }
        if risk >= 0.7 or kill_chain:
            message = (
                "\n====== ADVANCED DETECTION ALERT ======\n"
                f"Flow: {src_ip}:{src_port} -> {dst_ip}:{dst_port}\n"
                f"Protocol: {protocol}\n"
                f"Risk Score: {risk:.2f}\n"
                f"Signatures: {signature_hits}\n"
                f"Heuristic Score: {heuristic_score:.2f}\n"
                f"ML Flag: {ml_flag}\n"
                f"TLS Metadata: {tls_meta}\n"
                f"Carved Object: {carved}\n"
                f"File Scan: {file_scan}\n"
                f"Kill Chain Detected: {kill_chain}\n"
                "======================================\n"
            )
            self.flagged_logger.warning(message)
            if self.args.sentry:
                self.block_ip(src_ip)

    def _poll_xdp(self) -> None:
        """Thread target: poll the XDP collector until shutdown is requested."""
        while not self._xdp_stop.is_set():
            collector = self.xdp_collector
            if collector is None:
                # Nothing to poll; exit instead of spinning.
                return
            try:
                collector.poll()
            except Exception as e:
                self.logger.exception("Error polling XDP events: %s", e)
                # Back off briefly so a persistent failure does not become a
                # hot loop that floods the log.
                self._xdp_stop.wait(1.0)

    def _shutdown_xdp(self) -> None:
        """Stop the polling thread and detach the XDP program, if one is attached."""
        collector = self.xdp_collector
        if collector is None:
            return
        self._xdp_stop.set()
        if self.xdp_thread is not None and self.xdp_thread.is_alive():
            self.xdp_thread.join(timeout=1.0)
        try:
            collector.detach()
        except Exception as e:
            self.logger.exception("Failed to detach XDP program: %s", e)
        self.xdp_collector = None

    def identify_application(self, src_port: int, dst_port: int) -> Tuple[str, str]:
        """Return ``(name, description)`` for the first known port.

        The destination port is checked before the source port; unknown
        ports yield ``("Unknown", "Unknown application protocol")``.
        """
        for port in (dst_port, src_port):
            if port in self.COMMON_PORTS:
                return self.COMMON_PORTS[port]
        return ("Unknown", "Unknown application protocol")

    def block_ip(self, ip: str) -> None:
        """Record a block decision for ``ip`` in :attr:`BLOCK_LOG_PATH`.

        Only a ``Blocking IP <ip>`` line is appended to the file; no firewall
        rule is installed here. The line is written directly rather than via
        a shell command, so packet-derived text is never interpreted by a
        shell.
        """
        self.logger.info("Blocking IP %s (recording in %s)", ip, self.BLOCK_LOG_PATH)
        try:
            with open(self.BLOCK_LOG_PATH, "a", encoding="utf-8") as block_log:
                block_log.write(f"Blocking IP {ip}\n")
        except OSError as e:
            self.logger.error("Failed to record blocked IP %s: %s", ip, e)

    @staticmethod
    def _source_ip(packet) -> str:
        """Return the IPv4 or IPv6 source address, or ``"unknown"`` if neither layer exists."""
        if packet.haslayer(IP):
            return packet[IP].src
        if packet.haslayer(IPv6):
            return packet[IPv6].src
        return "unknown"

    @staticmethod
    def _transport_ports(packet) -> Tuple[int, int]:
        """Return ``(sport, dport)`` from the TCP or UDP layer, else ``(0, 0)``."""
        for layer in (TCP, UDP):
            if packet.haslayer(layer):
                return packet[layer].sport, packet[layer].dport
        return 0, 0

    def _is_noisy_discovery_traffic(self, packet) -> bool:
        """Return True for UDP from a private IPv4 source to multicast/broadcast."""
        if not packet.haslayer(IP):
            return False
        src_ip = packet[IP].src
        dst_ip = packet[IP].dst
        try:
            src_obj = ipaddress.ip_address(src_ip)
            dst_obj = ipaddress.ip_address(dst_ip)
        except ValueError:
            return False
        if not packet.haslayer(UDP):
            return False
        return src_obj.is_private and (dst_obj.is_multicast or dst_ip == "255.255.255.255")

    def log_flagged_ip(self, packet, flagged_signatures: List[Tuple[str, str]],
                       app_name: str, app_details: str) -> None:
        """Write a "FLAGGED IP ALERT" entry describing ``packet``.

        The entry lists addresses, ports, the value returned by
        ``get_local_process_info`` for the destination port, the flagged
        signatures and, unless ``--no-bgcheck`` was given, background
        information about the source address (waited on for up to 6 seconds).
        """
        source_ip = "Unknown"
        dest_ip = "Unknown"
        port_info = ""
        process_id = "N/A"
        if packet.haslayer(IP):
            source_ip = packet[IP].src
            dest_ip = packet[IP].dst
        elif packet.haslayer(IPv6):
            source_ip = packet[IPv6].src
            dest_ip = packet[IPv6].dst
        if packet.haslayer(TCP):
            tcp_layer = packet[TCP]
            port_info = f"TCP src: {tcp_layer.sport}, dst: {tcp_layer.dport}"
            process_id = get_local_process_info(tcp_layer.dport)
        elif packet.haslayer(UDP):
            udp_layer = packet[UDP]
            port_info = f"UDP src: {udp_layer.sport}, dst: {udp_layer.dport}"
            process_id = get_local_process_info(udp_layer.dport)

        ip_background = "No background info available."
        # Honour --no-bgcheck: skip the lookup entirely when it is set.
        if not getattr(self.args, "no_bgcheck", False):
            future = self.async_runner.run_coroutine(self.ip_lookup.get_ip_info(source_ip))
            try:
                info = future.result(timeout=6)
                if info:
                    wanted = ("hostname", "city", "region", "country", "org")
                    ip_background = "\n".join(
                        f"{k.capitalize()}: {v}" for k, v in info.items() if k in wanted
                    )
            except Exception as e:
                self.logger.exception("Error retrieving IP background info: %s", e)

        message = (
            "\n====== FLAGGED IP ALERT ======\n"
            f"Source IP: {source_ip}\n"
            f"Destination IP: {dest_ip}\n"
            f"Application: {app_name} ({app_details})\n"
            f"Port Info: {port_info}\n"
            f"Process ID: {process_id}\n"
            f"IP Background:\n{ip_background}\n"
            f"Flagged Signatures: {flagged_signatures}\n"
            "===============================\n"
        )
        self.flagged_logger.warning(message)

    def parse_payload(self, packet, app_name: str, payload: bytes) -> None:
        """Classify, signature-check and protocol-parse one payload.

        In sentry mode a payload the classifier marks as malicious causes the
        sender to be blocked and parsing to stop. Otherwise flagged
        signatures are reported (unless the packet is local discovery noise)
        and the payload is handed to the parser matching ``app_name``.
        """
        if self.args.sentry:
            if self.ml_classifier.classify(payload):
                # Use the IPv6 source too, so IPv6 senders are not recorded as "unknown".
                src_ip = self._source_ip(packet)
                self.logger.warning("Sentry mode: payload classified as malicious. Blocking IP %s", src_ip)
                self.block_ip(src_ip)
                return
        self.logger.info("Parsing payload for application: %s", app_name)
        flagged_signatures = self.payload_parser.analyze_hex_dump(payload)
        if flagged_signatures and not self._is_noisy_discovery_traffic(packet):
            sport, dport = self._transport_ports(packet)
            _, app_details = self.identify_application(sport, dport)
            self.log_flagged_ip(packet, flagged_signatures, app_name, app_details)
        if app_name == "HTTP":
            self.payload_parser.parse_http_payload(payload)
        elif app_name == "HTTPS":
            self.payload_parser.parse_tls_payload(payload[:64])
        elif app_name == "DNS":
            self.dns_parser.parse_dns_payload(payload)
        else:
            self.payload_parser.parse_text_payload(payload)

    def _analyze_with_llm(self, packet) -> None:
        """Expensive mode: ask the LLM for a verdict and block the sender if malicious."""
        payload = bytes(packet[Raw].load)
        normalized_payload = self.payload_parser.normalize_payload(payload)
        hex_payload = payload.hex()
        packet_summary = packet.summary()
        packet_details = packet.show(dump=True)
        self.logger.info("Expensive mode: analyzing packet with LLM.")
        label = llm_label_packet(packet_summary, normalized_payload, hex_payload, packet_details, logger=self.logger)
        if label == 1:
            self.logger.warning("Expensive Mode: Packet flagged as malicious by LLM.")
            self.log_flagged_ip(packet, flagged_signatures=[], app_name="Expensive Mode", app_details="LLM flagged malicious")
            self.block_ip(self._source_ip(packet))
        else:
            self.logger.info("Expensive Mode: Packet deemed benign by LLM.")

    def _log_dns_activity(self, packet, summary_str: str) -> None:
        """Log DNS queries found in the packet summary and in the DNS layer."""
        # Fallback: pull the query name out of Scapy's textual summary.
        dns_match = re.search(r"DNS Qry b'([^']+)'", summary_str)
        if dns_match:
            dns_query = dns_match.group(1)
            if not self.dns_parser.is_blacklisted(dns_query):
                self.dns_logger.info("DNS Query (fallback): %s", dns_query)
        if not packet.haslayer(DNS):
            return
        dns_layer = packet[DNS]
        if dns_layer.qr == 0 and dns_layer.qd is not None:
            try:
                if isinstance(dns_layer.qd, DNSQR):
                    dns_query = (dns_layer.qd.qname.decode()
                                 if isinstance(dns_layer.qd.qname, bytes)
                                 else dns_layer.qd.qname)
                else:
                    dns_query = ", ".join(
                        q.qname.decode() if isinstance(q.qname, bytes) else q.qname
                        for q in dns_layer.qd
                    )
            except Exception:
                dns_query = str(dns_layer.qd)
            self.logger.info("DNS Query (from DNS layer): %s", dns_query)
        # NOTE(review): the flattened source does not show whether this parse
        # ran for every DNS packet or for queries only; it is applied to every
        # DNS layer here - confirm against the original nesting.
        try:
            raw_dns_payload = bytes(dns_layer)
            self.dns_parser.parse_dns_payload(raw_dns_payload)
        except Exception as e:
            self.logger.exception("Error processing DNS layer: %s", e)

    def _handle_tcp(self, packet, ip_layer) -> None:
        """Log a TCP segment, run stream analysis and TLS/payload parsing."""
        tcp_layer = packet[TCP]
        self.logger.info("TCP: sport=%s, dport=%s", tcp_layer.sport, tcp_layer.dport)
        app_name, app_details = self.identify_application(tcp_layer.sport, tcp_layer.dport)
        self.logger.info("Identified Application: %s (%s)", app_name, app_details)
        has_raw = bool(packet.haslayer(Raw))
        if has_raw:
            self._analyze_reassembled_tcp(
                src_ip=ip_layer.src,
                src_port=tcp_layer.sport,
                dst_ip=ip_layer.dst,
                dst_port=tcp_layer.dport,
                tcp_seq=int(tcp_layer.seq),
                payload=bytes(packet[Raw].load),
            )
        # TLS is falsy when Scapy's TLS support could not be imported.
        has_tls = bool(TLS and packet.haslayer(TLS))
        if app_name == "HTTPS" or has_tls:
            if has_tls:
                tls_layer = packet[TLS]
                self.logger.info("TLS Record: %s", tls_layer.summary())
                if packet.haslayer(TLSClientHello):
                    client_hello = packet[TLSClientHello]
                    self.logger.info("TLS ClientHello: %s", client_hello.summary())
                    if hasattr(client_hello, 'servernames'):
                        self.logger.info("SNI: %s", client_hello.servernames)
            elif has_raw:
                self.payload_parser.parse_tls_payload(bytes(packet[Raw].load))
            # NOTE(review): nesting of this key-log notice is ambiguous in the
            # flattened source; it is emitted for every HTTPS/TLS segment here.
            keylog = os.getenv("SSLKEYLOGFILE", "")
            if keylog:
                self.logger.info("TLS plaintext decryption can use SSLKEYLOGFILE at: %s", keylog)
            else:
                self.logger.info("TLS inspection limited to metadata (SNI/JA3/cert) unless SSLKEYLOGFILE or interception is configured.")
        elif has_raw:
            self.parse_payload(packet, app_name, bytes(packet[Raw].load))

    def _handle_udp(self, packet, ip_layer) -> None:
        """Log a UDP datagram, record DNS queries and parse its payload."""
        udp_layer = packet[UDP]
        self.logger.info("UDP: sport=%s, dport=%s", udp_layer.sport, udp_layer.dport)
        app_name, app_details = self.identify_application(udp_layer.sport, udp_layer.dport)
        self.logger.info("Identified Application: %s (%s)", app_name, app_details)
        if packet.haslayer(Raw):
            payload = bytes(packet[Raw].load)
            if app_name == "DNS":
                parsed_dns = self.protocol_decoder.decode_dns(payload)
                self.correlator.add_event(ip_layer.src, "dns_query", {"query": parsed_dns.get("query", "")})
            self.parse_payload(packet, app_name, payload)

    def packet_handler(self, packet) -> None:
        """Per-packet callback passed to Scapy's ``sniff``.

        Order of processing: expensive-mode LLM analysis (if enabled and the
        packet has a Raw layer), source-address scope filter, red-alert
        country whitelist, DNS logging, then link/network/transport logging
        and payload analysis.
        """
        if self.args.expensive and packet.haslayer(Raw):
            self._analyze_with_llm(packet)
            return

        is_ipv4 = bool(packet.haslayer(IP))
        if is_ipv4:
            ip_layer = packet[IP]
        elif packet.haslayer(IPv6):
            ip_layer = packet[IPv6]
        elif packet.haslayer(ARP):
            # ARP carries no IP layer, so it has to be reported here; a check
            # placed after the IP filter can never be reached.
            arp = packet[ARP]
            self.logger.info("ARP: op=%s, src=%s, dst=%s", arp.op, arp.psrc, arp.pdst)
            return
        else:
            self.logger.warning("Packet without IP/IPv6 layer")
            return

        # --local-only keeps only private sources; otherwise only public ones.
        ip_str = ip_layer.src
        if bool(self.args.local_only) != ipaddress.ip_address(ip_str).is_private:
            return

        if not self.args.local_only and self.args.red_alert:
            future = self.async_runner.run_coroutine(self.ip_lookup.get_ip_info(ip_str))
            try:
                # The lookup may yield None or omit the country; treat both as unknown.
                ip_info = future.result(timeout=6) or {}
                country = (ip_info.get("country") or "").upper()
                if country in self.args.whitelist:
                    self.logger.info("Skipping packet from whitelisted country: %s", country)
                    return
            except Exception as e:
                self.logger.exception("Error during red alert IP filtering: %s", e)

        summary_str = packet.summary()
        self.logger.info("=" * 80)
        self.logger.info("Packet: %s", summary_str)
        self._log_dns_activity(packet, summary_str)

        if not packet.haslayer(Ether):
            self.logger.warning("No Ethernet layer found.")
            return
        eth = packet[Ether]
        self.logger.info("Ethernet: src=%s, dst=%s, type=0x%04x", eth.src, eth.dst, eth.type)

        if is_ipv4:
            self.logger.info("IPv4: src=%s, dst=%s, ttl=%s, proto=%s",
                             ip_layer.src, ip_layer.dst, ip_layer.ttl, ip_layer.proto)
        else:
            self.logger.info("IPv6: src=%s, dst=%s, hlim=%s",
                             ip_layer.src, ip_layer.dst, ip_layer.hlim)

        if packet.haslayer(TCP):
            self._handle_tcp(packet, ip_layer)
        elif packet.haslayer(UDP):
            self._handle_udp(packet, ip_layer)
        elif packet.haslayer(ICMP):
            icmp_layer = packet[ICMP]
            self.logger.info("ICMP: type=%s, code=%s", icmp_layer.type, icmp_layer.code)
        else:
            self.logger.warning("Unsupported transport layer.")

    def run(self) -> None:
        """Verify the capture backend, sniff until interrupted, then clean up XDP.

        A BPF filter of ``"ip"`` is applied unless ``--local-only`` is set.
        Any XDP program attached in :meth:`__init__` is detached on the way
        out, whether capture ended normally, by Ctrl-C or with an error.
        """
        self.logger.info("Starting Enhanced Packet Sniffer on interface '%s'", self.args.interface)
        try:
            backend_ok, backend_msg = check_packet_capture_backend(logger=self.logger)
            if not backend_ok:
                self.logger.error("%s", backend_msg)
                return
            try:
                bpf_filter = "ip" if not self.args.local_only else ""
                sniff(iface=self.args.interface, prn=self.packet_handler, store=0, filter=bpf_filter)
            except KeyboardInterrupt:
                self.logger.info("Stopping packet capture (KeyboardInterrupt received)...")
            except Exception as e:
                self.logger.exception("Error during packet capture: %s", e)
        finally:
            self._shutdown_xdp()
async def async_llm_label_packet(packet_summary: str, normalized_payload: str, hex_payload: str, packet_details: str, logger: Optional[logging.Logger] = None) -> int:
    """
    Ask the OpenAI chat API whether a packet looks malicious or benign.

    The prompt contains the packet summary, the Scapy dump, the first 300
    characters of the normalized payload, the first 200 characters of the hex
    payload and the latest XDP metadata line.

    Returns 1 if the reply contains the word 'malicious', 0 otherwise. 0 is
    also returned when the client library or the OPENAI_API_KEY environment
    variable is missing, or when the request fails.
    """
    logger = logger or logging.getLogger(__name__)
    if AsyncOpenAI is None:
        logger.error("AsyncOpenAI client is not available. Please install the openai package (>=1.0.0).")
        return 0
    api_key = os.getenv("OPENAI_API_KEY")
    if not api_key:
        logger.error("OpenAI API key not found in environment (OPENAI_API_KEY).")
        return 0
    client = AsyncOpenAI(api_key=api_key)
    xdp_details = get_xdp_metadata_details()
    prompt = (
        f"Examine the following packet details and decide if the packet is malicious or benign.\n\n"
        f"Packet Summary:\n{packet_summary}\n\n"
        f"Packet Detailed Info (Scapy dump):\n{packet_details}\n\n"
        f"Normalized Payload (first 300 characters):\n{normalized_payload[:300]}\n\n"
        f"Hex Payload (first 200 characters):\n{hex_payload[:200]}\n\n"
        f"{xdp_details}\n\n"
        "Answer with a single word: 'malicious' or 'benign'."
    )
    logger.debug("LLM Prompt:\n%s", prompt)
    try:
        try:
            response = await client.chat.completions.create(
                model="gpt-3.5-turbo",
                messages=[
                    {"role": "system", "content": "You are a network security analyst."},
                    {"role": "user", "content": prompt}
                ],
                max_tokens=10,
                temperature=0
            )
        finally:
            # A client is created per call; close it so its HTTP connections
            # are released before the caller tears down the event loop.
            await client.close()
        # The message content may be None; treat that as an empty answer.
        answer = (response.choices[0].message.content or "").strip().lower()
        logger.debug("LLM response: %s", answer)
        return 1 if "malicious" in answer else 0
    except Exception as e:
        logger.exception("Error querying OpenAI API: %s", e)
        return 0
def llm_label_packet(packet_summary: str, normalized_payload: str, hex_payload: str, packet_details: str, logger: Optional[logging.Logger] = None) -> int:
    """
    Blocking front-end for async_llm_label_packet.

    Runs the coroutine to completion on a private event loop that is created
    for this call and closed afterwards. Returns the coroutine's label, or 0
    if running it raises.
    """
    log = logger or logging.getLogger(__name__)
    loop = asyncio.new_event_loop()
    try:
        return loop.run_until_complete(
            async_llm_label_packet(
                packet_summary,
                normalized_payload,
                hex_payload,
                packet_details,
                logger=log,
            )
        )
    except Exception as exc:
        log.exception("Error in LLM labeling: %s", exc)
        return 0
    finally:
        # Finish any pending async generators before discarding the loop.
        loop.run_until_complete(loop.shutdown_asyncgens())
        loop.close()
def build_dataset_main(interface: str, num_samples: int, output_path: str, use_llm: bool,
                       logger: Optional[logging.Logger] = None, timeout: int = 60) -> None:
    """
    Capture packets that carry a Raw payload, label them and save them as CSV.

    Args:
        interface: Network interface to capture on.
        num_samples: Number of labeled samples to collect; capture stops as
            soon as this many have been gathered.
        output_path: Destination CSV file with ``payload`` (hex) and
            ``label`` columns.
        use_llm: If True, label through ``llm_label_packet``; otherwise
            prompt on stdin (anything other than ``1`` counts as benign).
        logger: Optional logger; defaults to this module's logger.
        timeout: Upper bound in seconds for the capture (default 60).
    """
    logger = logger or logging.getLogger(__name__)
    logger.info("Starting dataset capture on interface %s; capturing %d samples.", interface, num_samples)
    samples = []
    # One parser is enough; it does not need to be rebuilt for every packet.
    payload_parser = PayloadParser(logger=logger)

    def quota_reached(_packet=None) -> bool:
        """Tell Scapy (via stop_filter) that enough samples were collected."""
        return len(samples) >= num_samples

    def packet_callback(packet) -> None:
        # Returns None on purpose: Scapy prints whatever ``prn`` returns and
        # does not use it as a stop signal.
        if quota_reached() or not packet.haslayer(Raw):
            return
        payload = bytes(packet[Raw].load)
        hex_payload = payload.hex()
        normalized_payload = payload_parser.normalize_payload(payload)
        packet_summary = packet.summary()
        packet_details = packet.show(dump=True)
        print("\nPacket captured:")
        print(packet_summary)
        print("Packet Detailed Info:")
        print(packet_details)
        print("Normalized Payload (first 40 characters):", normalized_payload[:40])
        if use_llm:
            label = llm_label_packet(packet_summary, normalized_payload, hex_payload, packet_details, logger=logger)
            print(f"LLM labeled this packet as: {'malicious' if label == 1 else 'benign'}")
        else:
            user_input = input("Label this packet as malicious (1) or benign (0) [default=0]: ").strip()
            label = 1 if user_input == "1" else 0
        samples.append({"payload": hex_payload, "label": label})

    scapy.sniff(iface=interface, prn=packet_callback, store=0, timeout=timeout, stop_filter=quota_reached)

    if not samples:
        logger.error("No samples were captured.")
        return
    try:
        with open(output_path, "w", newline="") as csvfile:
            writer = csv.DictWriter(csvfile, fieldnames=["payload", "label"])
            writer.writeheader()
            writer.writerows(samples)
        logger.info("Dataset built and saved to %s (%d samples).", output_path, len(samples))
    except Exception as e:
        logger.error("Failed to write dataset to %s: %s", output_path, e)
def train_model_main(dataset_path: str, model_output_path: str, logger: Optional[logging.Logger] = None) -> None:
    """Train a model from ``dataset_path`` by delegating to ``MLClassifier.train_model``.

    ``model_output_path`` is forwarded as the destination argument; the
    module logger is used when no logger is supplied.
    """
    active_logger = logger or logging.getLogger(__name__)
    MLClassifier.train_model(
        dataset_path,
        model_output_path,
        logger=active_logger,
    )
def get_missing_runtime_dependencies() -> List[str]:
    """List the required packages whose module-level reference is ``None``.

    Names are returned in a fixed order: aiohttp, PyYAML, cachetools, numpy,
    scapy. An empty list means every checked reference is set.
    """
    checked = (
        ("aiohttp", aiohttp),
        ("PyYAML", yaml),
        ("cachetools", TTLCache),
        ("numpy", np),
        ("scapy", scapy),
    )
    return [package for package, reference in checked if reference is None]
def detect_operating_system() -> str:
    """Map ``platform.system()`` to ``"windows"``, ``"linux"``, ``"macos"`` or ``"other"``."""
    reported = platform.system().lower()
    prefix_to_label = (
        ("win", "windows"),
        ("linux", "linux"),
        ("darwin", "macos"),
    )
    for prefix, label in prefix_to_label:
        if reported.startswith(prefix):
            return label
    return "other"
def get_required_pip_packages(os_name: str) -> List[str]:
    """Return the pip package names to install for ``os_name``.

    Every platform gets the common list; ``"linux"`` additionally gets
    ``bcc`` appended at the end.
    """
    packages = [
        "aiohttp",
        "PyYAML",
        "cachetools",
        "numpy",
        "scapy",
        "yara-python",
        "openai",
        "psutil",
        "pandas",
        "scikit-learn",
        "joblib",
    ]
    if os_name == "linux":
        packages.append("bcc")
    return packages
def install_npcap_windows(logger: Optional[logging.Logger] = None) -> bool:
    """Make sure a packet-capture backend is available, installing Npcap if not.

    If ``check_packet_capture_backend`` already reports success nothing is
    downloaded. Otherwise the Npcap 1.87 installer is fetched into the
    directory named by ``TEMP`` (or the current directory), launched, and the
    backend is checked again.

    Returns:
        True when the backend is available on return, False when the
        download, the installer run or the final check fails.
    """
    log = logger or logging.getLogger(__name__)

    already_present, _ = check_packet_capture_backend(logger=log)
    if already_present:
        log.info("Npcap already available on this Windows machine.")
        return True

    npcap_url = "https://npcap.com/dist/npcap-1.87.exe"
    download_dir = os.environ.get("TEMP", ".")
    installer_path = os.path.join(download_dir, "npcap-1.87.exe")
    log.info("Npcap not detected. Downloading official installer from %s", npcap_url)
    try:
        from urllib.request import urlretrieve
        urlretrieve(npcap_url, installer_path)
    except Exception as download_error:
        log.error("Failed to download Npcap installer: %s", download_error)
        return False

    log.info("Launching Npcap installer (may prompt for UAC/admin approval)...")
    try:
        subprocess.run([installer_path], check=True)
    except subprocess.CalledProcessError as run_error:
        log.error("Npcap installer exited with code %s", run_error.returncode)
        return False
    except Exception as launch_error:
        log.error("Failed to launch Npcap installer: %s", launch_error)
        return False

    # Re-check rather than trusting the installer's exit status alone.
    verified, verify_msg = check_packet_capture_backend(logger=log)
    if verified:
        log.info("Npcap installation verified successfully.")
        return True
    log.error("Npcap still not available after installer run: %s", verify_msg)
    return False
def install_dependencies(logger: Optional[logging.Logger] = None) -> bool:
    """Install the Python packages for this OS and, on Windows, Npcap.

    ``pip3 install`` is tried first. If ``pip3`` is missing, cannot be
    executed, or exits with an error, ``<current interpreter> -m pip
    install`` is used instead.

    Returns:
        True if the packages were installed (and, on Windows, the Npcap step
        succeeded); False otherwise.
    """
    logger = logger or logging.getLogger(__name__)
    os_name = detect_operating_system()
    packages = get_required_pip_packages(os_name)
    logger.info("Detected operating system: %s", os_name)
    logger.info("Installing OS-specific Python dependencies...")
    logger.info("Installing dependencies with pip3...")

    installed = False
    try:
        subprocess.run(["pip3", "install", *packages], check=True)
    except FileNotFoundError:
        logger.warning("pip3 not found. Falling back to python -m pip.")
    except OSError as exc:
        # e.g. PermissionError for a non-executable pip3: fall back instead
        # of aborting the whole install.
        logger.warning("pip3 could not be executed (%s). Falling back to python -m pip.", exc)
    except subprocess.CalledProcessError as exc:
        logger.warning("pip3 install returned an error (%s). Falling back to python -m pip.", exc.returncode)
    else:
        logger.info("Dependencies installed successfully via pip3.")
        installed = True

    if not installed:
        try:
            subprocess.run([sys.executable, "-m", "pip", "install", *packages], check=True)
        except subprocess.CalledProcessError as exc:
            logger.error("Dependency installation failed (exit code %s).", exc.returncode)
            return False
        except Exception as exc:
            logger.error("Dependency installation failed: %s", exc)
            return False
        logger.info("Dependencies installed successfully via python -m pip.")

    if os_name == "windows" and not install_npcap_windows(logger=logger):
        logger.error("Windows install phase incomplete: Npcap setup failed.")
        return False
    return True
def is_likely_loopback(interface_name: str) -> bool:
    """Return True if ``interface_name`` looks like a loopback interface.

    A name matches when it contains the word ``loopback`` (case-insensitive,
    e.g. "Npcap Loopback Adapter") or is ``lo`` optionally followed by digits
    (``lo``, ``lo0``). A bare substring test for ``lo`` is deliberately not
    used: it would also match ordinary interfaces such as ``wlo1`` or
    "Local Area Connection".
    """
    lowered = interface_name.strip().lower()
    if "loopback" in lowered:
        return True
    if not lowered.startswith("lo"):
        return False
    suffix = lowered[2:]
    return suffix == "" or suffix.isdigit()
def auto_detect_interface(logger: Optional[logging.Logger] = None) -> str:
    """Pick a capture interface, preferring anything that is not loopback.

    Candidates are tried in this order: Scapy's configured default
    interface, the interface Scapy routes 8.8.8.8 through, then the first
    non-loopback entry of Scapy's interface list. If the list holds only
    loopback-looking names, its first entry is returned.

    Raises:
        RuntimeError: if none of the strategies yields an interface.
    """
    log = logger or logging.getLogger(__name__)

    # Strategy 1: Scapy's configured default interface.
    try:
        configured = str(scapy.conf.iface)
        if configured and not is_likely_loopback(configured):
            log.info("Auto-detected interface from Scapy default: %s", configured)
            return configured
    except Exception:
        log.debug("Unable to use Scapy default interface for auto-detection.")

    # Strategy 2: the interface used to reach a public address.
    try:
        via_route = scapy.conf.route.route("8.8.8.8")[0]
        if via_route and not is_likely_loopback(via_route):
            log.info("Auto-detected interface from routing table: %s", via_route)
            return via_route
    except Exception:
        log.debug("Unable to use routing table for interface auto-detection.")

    # Strategy 3: scan the interface list, falling back to its first entry.
    try:
        available = scapy.get_if_list()
        usable = next(
            (name for name in available if name and not is_likely_loopback(name)),
            None,
        )
        if usable is not None:
            log.info("Auto-detected interface from interface list: %s", usable)
            return usable
        if available:
            log.info("Falling back to first available interface: %s", available[0])
            return available[0]
    except Exception as list_error:
        log.debug("Failed to list interfaces during auto-detection: %s", list_error)

    raise RuntimeError("Could not auto-detect a usable network interface.")
def parse_arguments(argv: Optional[List[str]] = None) -> argparse.Namespace:
    """Parse the sniffer's command-line options.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, as before; passing a list allows the
            parser to be driven programmatically.

    Returns:
        The populated namespace. ``whitelist`` is returned as the raw
        comma-separated string; the caller splits it.
    """
    parser = argparse.ArgumentParser(
        description="Enhanced Packet Sniffer with Red Alert, Whitelist, Sentry Mode, Dataset Building, Training, and LLM-assisted Labeling"
    )
    parser.add_argument("-i", "--interface", type=str, default="",
                        help="Network interface to sniff on")
    parser.add_argument("--install", action="store_true",
                        help="Install required Python dependencies using pip3 and exit")
    parser.add_argument("--start", action="store_true",
                        help="Start sniffing with automatic interface detection when --interface is not provided")
    parser.add_argument("-l", "--logfile", type=str, default="sniffer.log",
                        help="Path to the main log file")
    parser.add_argument("--no-bgcheck", action="store_true",
                        help="Disable IP background lookup")
    parser.add_argument("--local-only", action="store_true",
                        help="Capture only local (private) traffic")
    parser.add_argument("--red-alert", action="store_true",
                        help="Enable red alert mode: only log packets from non-allied countries")
    parser.add_argument("--whitelist", type=str, default="US",
                        help="Comma-separated list of allied (whitelisted) country codes (default: US)")
    parser.add_argument("--sentry", action="store_true",
                        help="Enable sentry mode (ML-based blocking of malicious packets)")
    parser.add_argument("--model-path", type=str, default="model.pkl",
                        help="Path to the ML model file (used in sentry mode)")
    parser.add_argument("--train", action="store_true",
                        help="Run training mode to build a new ML model from a dataset")
    parser.add_argument("--dataset", type=str, default="",
                        help="Path to the CSV dataset for training")
    parser.add_argument("--build-dataset", action="store_true",
                        help="Capture and build a labeled dataset interactively")
    parser.add_argument("--num-samples", type=int, default=10,
                        help="Number of samples to capture when building the dataset")
    parser.add_argument("--dataset-out", type=str, default="training_data.csv",
                        help="Path to save the built dataset CSV file")
    parser.add_argument("--llm-label", action="store_true",
                        help="Use the OpenAI API to automatically label packets in dataset building mode")
    parser.add_argument("--xdp", action="store_true",
                        help="Enable XDP (Express Data Path) to gather additional packet metadata")
    parser.add_argument("--expensive", action="store_true",
                        help="Use LLM to analyze all packets in realtime (expensive mode)")
    parser.add_argument("--yara-rules", type=str, default="",
                        help="Path to YARA rules file for signature scanning")
    parser.add_argument("--carve-dir", type=str, default="carved_objects",
                        help="Directory where carved files/objects are written")
    parser.add_argument("-v", "--verbosity", type=str, default="INFO",
                        help="Logging level (DEBUG, INFO, WARNING, ERROR)")
    return parser.parse_args(argv)
if __name__ == "__main__":
    # Script entry point: parse options, configure logging, then dispatch to
    # install / dataset-building / training / sniffing mode. Every failure
    # path exits with status 1.
    args = parse_arguments()
    # Normalise the whitelist into a set of upper-case codes; empty entries
    # (e.g. from a trailing comma) are dropped so "" is never whitelisted.
    args.whitelist = {code.strip().upper() for code in args.whitelist.split(',') if code.strip()}
    logging_level = getattr(logging, args.verbosity.upper(), logging.INFO)
    setup_logging("logging.yaml")
    logging.getLogger().setLevel(logging_level)
    logger = logging.getLogger(__name__)

    if args.install:
        ok = install_dependencies(logger=logger)
        raise SystemExit(0 if ok else 1)

    missing_runtime_deps = get_missing_runtime_dependencies()
    if missing_runtime_deps:
        logger.error(
            "Missing required dependencies: %s. Run with --install first.",
            ", ".join(missing_runtime_deps)
        )
        raise SystemExit(1)

    if args.train and not args.build_dataset:
        if not args.dataset:
            logger.error("Training mode requires a dataset. Please provide --dataset <path>")
            # Exit non-zero like every other usage error in this script.
            raise SystemExit(1)
        train_model_main(args.dataset, args.model_path, logger=logger)
    else:
        # Dataset and sniffer modes both need an interface; resolve it once.
        mode_label = "Dataset" if args.build_dataset else "Sniffer"
        if args.start and not args.interface:
            try:
                args.interface = auto_detect_interface(logger=logger)
            except Exception as exc:
                logger.error("--start failed to auto-detect an interface: %s", exc)
                raise SystemExit(1)
        if not args.interface:
            logger.error("%s mode requires --interface or --start for auto-detection", mode_label)
            raise SystemExit(1)
        if args.build_dataset:
            build_dataset_main(args.interface, args.num_samples, args.dataset_out, args.llm_label, logger=logger)
        else:
            sniffer = PacketSniffer(args)
            sniffer.run()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment