Last active
February 28, 2026 12:11
-
-
Save gary23w/3bbb51566a7ba1a7bbc1a7e2d384b8b6 to your computer and use it in GitHub Desktop.
simple-packet-sniffer.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""
simplepacketsniffer.py 0.4

Usage examples:

Dataset Building Mode (with LLM labeling):
    sudo python simplepacketsniffer.py --build-dataset --llm-label --num-samples 10 --dataset-out training_data.csv -i INTERFACE -v INFO [--xdp]

Training Mode:
    sudo python simplepacketsniffer.py --train --dataset training_data.csv --model-path model.pkl -v INFO

Sniffer (Protection) Mode:
    sudo python simplepacketsniffer.py -i INTERFACE --red-alert --whitelist US,CA --sentry --model-path model.pkl -v DEBUG [--xdp]

Expensive Mode (Real-time LLM analysis for every packet):
    sudo python simplepacketsniffer.py --expensive -i INTERFACE -v DEBUG [--xdp]

Example ``logging.yaml`` consumed by :func:`setup_logging`:

    version: 1
    disable_existing_loggers: false
    formatters:
      standard:
        format: '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
      simple:
        format: '%(asctime)s - %(levelname)s - %(message)s'
    handlers:
      console:
        class: logging.StreamHandler
        level: DEBUG
        formatter: standard
        stream: ext://sys.stdout
      file_handler:
        class: logging.handlers.RotatingFileHandler
        level: DEBUG
        formatter: standard
        filename: sniffer.log
        maxBytes: 1048576
        backupCount: 3
      flagged_ip_handler:
        class: logging.handlers.RotatingFileHandler
        level: WARNING
        formatter: standard
        filename: flagged_ips.log
        maxBytes: 1048576
        backupCount: 3
      dns_handler:
        class: logging.handlers.RotatingFileHandler
        level: INFO
        formatter: standard
        filename: dns_queries.log
        maxBytes: 1048576
        backupCount: 3
    loggers:
      PacketSniffer:
        level: DEBUG
        handlers: [console, file_handler]
        propagate: no
      FlaggedIPLogger:
        level: WARNING
        handlers: [flagged_ip_handler, console]
        propagate: no
      DNSQueryLogger:
        level: INFO
        handlers: [dns_handler, console]
        propagate: no
    root:
      level: DEBUG
      handlers: [console]
"""
from __future__ import annotations
import argparse
import asyncio
import binascii
import csv
import gzip
import hashlib
import ipaddress
import logging
import logging.config
import os
import platform
import re
import shutil
import string
import struct
import subprocess
import sys
import threading
import time
import zlib
from collections import Counter, defaultdict, deque
from dataclasses import dataclass, field
from datetime import datetime, timezone
from typing import Any, Dict, List, Optional, Tuple
| try: | |
| import aiohttp | |
| except ImportError: | |
| aiohttp = None | |
| try: | |
| import yaml | |
| except ImportError: | |
| yaml = None | |
| try: | |
| from cachetools import TTLCache | |
| except ImportError: | |
| TTLCache = None | |
| try: | |
| import numpy as np | |
| except ImportError: | |
| np = None | |
| try: | |
| import scapy.all as scapy | |
| from scapy.all import sniff, Ether, IP, IPv6, TCP, UDP, ICMP, ARP, Raw | |
| from scapy.layers.dns import DNS, DNSQR | |
| except ImportError: | |
| scapy = None | |
| sniff = None | |
| Ether = None | |
| IP = None | |
| IPv6 = None | |
| TCP = None | |
| UDP = None | |
| ICMP = None | |
| ARP = None | |
| Raw = None | |
| DNS = None | |
| DNSQR = None | |
| try: | |
| from scapy.layers.tls.all import TLS, TLSClientHello, TLSHandshake | |
| except ImportError: | |
| TLS = None | |
| TLSClientHello = None | |
| TLSHandshake = None | |
| try: | |
| from openai import AsyncOpenAI | |
| except ImportError: | |
| AsyncOpenAI = None | |
| try: | |
| from bcc import BPF | |
| except ImportError: | |
| BPF = None | |
| try: | |
| import yara | |
| except ImportError: | |
| yara = None | |
# Metadata for the most recent packet seen by the optional XDP/eBPF capture
# path; None until the first event.  NOTE(review): the producer/consumer of
# this global is outside this chunk -- confirm before relying on its shape.
last_xdp_metadata: Optional[Dict[str, Any]] = None
def setup_logging(config_file: str = "logging.yaml") -> None:
    """Configure logging from a YAML dictConfig file.

    Falls back to ``logging.basicConfig`` (and logs a warning) when PyYAML is
    unavailable, the file cannot be read, or the config is invalid.  Also
    best-effort switches stdout/stderr to UTF-8 with replacement so payload
    text with odd characters cannot crash the console handler.
    """
    try:
        # Python 3.7+: TextIOWrapper.reconfigure; absent on replaced streams.
        if hasattr(sys.stdout, "reconfigure"):
            sys.stdout.reconfigure(encoding="utf-8", errors="replace")
        if hasattr(sys.stderr, "reconfigure"):
            sys.stderr.reconfigure(encoding="utf-8", errors="replace")
    except Exception:
        pass
    try:
        if yaml is None:
            raise ImportError("PyYAML is not installed")
        with open(config_file, "r") as handle:
            parsed = yaml.safe_load(handle.read())
        logging.config.dictConfig(parsed)
    except Exception as exc:
        logging.basicConfig(
            level=logging.INFO,
            format="%(asctime)s - %(levelname)s - %(message)s"
        )
        logging.getLogger(__name__).warning(
            "Failed to load logging config from %s, using basic config: %s",
            config_file, exc
        )
def check_packet_capture_backend(logger: Optional[logging.Logger] = None) -> Tuple[bool, str]:
    """Return ``(ok, reason)`` describing whether packet capture can work.

    On Windows this additionally verifies that scapy found a usable
    Npcap/WinPcap provider; on other platforms an importable scapy suffices.
    """
    logger = logger or logging.getLogger(__name__)
    if scapy is None:
        return False, "Scapy is not installed. Run with --install first."
    if os.name != "nt":
        return True, ""
    pcap_enabled = bool(getattr(scapy.conf, "use_pcap", False))
    l2_socket_repr = str(getattr(scapy.conf, "L2listen", ""))
    npcap_broken = (
        not pcap_enabled
        or "_NotAvailableSocket" in l2_socket_repr
        or "wpcap.dll missing" in l2_socket_repr.lower()
    )
    if npcap_broken:
        return (
            False,
            "Npcap/WinPcap is not installed or not available. Install Npcap from https://npcap.com/ and rerun."
        )
    return True, ""
class AsyncRunner:
    """Owns a dedicated asyncio event loop running in a daemon thread.

    Lets synchronous packet-handling code hand coroutines (lookups, LLM
    calls) to a background loop without blocking the capture path.
    """

    def __init__(self) -> None:
        self.loop = asyncio.new_event_loop()
        self.thread = threading.Thread(target=self._run_loop, daemon=True)
        self.thread.start()

    def _run_loop(self) -> None:
        # Bind the loop to the background thread, then serve until stopped.
        asyncio.set_event_loop(self.loop)
        self.loop.run_forever()

    def run_coroutine(self, coro):
        """Schedule *coro* on the background loop; returns a concurrent.futures.Future."""
        return asyncio.run_coroutine_threadsafe(coro, self.loop)
class IPLookup:
    """Async ipinfo.io lookups with a TTL-bounded in-memory cache."""

    def __init__(self, ttl: int = 3600, cache_size: int = 1000, logger: Optional[logging.Logger] = None) -> None:
        if TTLCache is None or aiohttp is None:
            raise ImportError("IPLookup requires 'cachetools' and 'aiohttp'.")
        self.cache = TTLCache(maxsize=cache_size, ttl=ttl)
        self.logger = logger or logging.getLogger(__name__)

    async def fetch_ip_info(self, session: aiohttp.ClientSession, ip: str) -> Dict[str, Any]:
        """Look up *ip* using an existing session; returns {} on any failure."""
        try:
            return self.cache[ip]
        except KeyError:
            pass  # not cached (or TTL-expired) -- fall through to the network
        try:
            async with session.get(f"https://ipinfo.io/{ip}/json", timeout=5) as resp:
                if resp.status == 200:
                    data = await resp.json()
                    self.cache[ip] = data
                    return data
                self.logger.warning("IP lookup for %s returned status %s", ip, resp.status)
        except aiohttp.ClientError as e:
            self.logger.exception("Network error during IP lookup for %s: %s", ip, e)
        except Exception:
            self.logger.exception("Unexpected error during IP lookup for %s", ip)
        return {}

    async def get_ip_info(self, ip: str) -> Dict[str, Any]:
        """One-shot lookup that owns (and closes) its own ClientSession."""
        async with aiohttp.ClientSession() as session:
            return await self.fetch_ip_info(session, ip)
class DNSParser:
    """Minimal DNS message parser used for logging queries.

    Fix over the original: ``parse_dns_name`` now tracks visited offsets, so
    a malicious message whose compression pointer points at itself (or forms
    a pointer cycle, RFC 1035 section 4.1.4) can no longer drive unbounded
    recursion and crash the sniffer with ``RecursionError``.
    """

    def __init__(self, logger: Optional[logging.Logger] = None, blacklist: Optional[List[str]] = None) -> None:
        self.logger = logger or logging.getLogger(__name__)
        # Domains never reported (e.g. our own ipinfo.io lookup traffic).
        self.blacklist = set(blacklist or ['ipinfo.io', 'ipinfo.io.'])  # TODO: add wildcard domains

    def is_blacklisted(self, domain: str) -> bool:
        """Return True if *domain* exactly matches a blacklisted name."""
        return domain in self.blacklist

    def parse_dns_name(self, payload: bytes, offset: int) -> Tuple[str, int]:
        """Decode a (possibly compressed) DNS name starting at *offset*.

        Returns ``(dotted_name, offset_after_name)``.  The returned offset is
        the position just past the name at its ORIGINAL location, even when
        the name jumps through compression pointers.  Truncated or cyclic
        input yields a best-effort partial name instead of raising.
        """
        labels: List[str] = []
        visited = set()   # offsets already consumed -> pointer-cycle guard
        pos = offset
        resume = None     # offset to return once we follow the first pointer
        while pos < len(payload) and pos not in visited:
            visited.add(pos)
            length = payload[pos]
            if (length & 0xC0) == 0xC0:
                # Compression pointer: two bytes, 14-bit target offset.
                if pos + 1 >= len(payload):
                    break
                if resume is None:
                    resume = pos + 2
                pos = ((length & 0x3F) << 8) | payload[pos + 1]
                continue
            if length == 0:
                # Root label terminates the name.
                pos += 1
                break
            label = payload[pos + 1: pos + 1 + length].decode('utf-8', errors='replace')
            labels.append(label)
            pos += 1 + length
        return ".".join(labels), resume if resume is not None else pos

    def parse_dns_payload(self, payload: bytes) -> None:
        """Log the DNS header and every entry of the question section."""
        self.logger.info("Parsing DNS payload...")
        if len(payload) < 12:
            self.logger.warning("DNS payload too short to parse.")
            return
        try:
            transaction_id, flags, qdcount, ancount, nscount, arcount = struct.unpack("!HHHHHH", payload[:12])
            self.logger.info("DNS Header: ID=%#04x, Flags=%#04x, QD=%d, AN=%d, NS=%d, AR=%d",
                             transaction_id, flags, qdcount, ancount, nscount, arcount)
        except Exception as e:
            self.logger.exception("Error parsing DNS header: %s", e)
            return
        offset = 12
        for i in range(qdcount):
            try:
                domain, offset = self.parse_dns_name(payload, offset)
                if offset + 4 > len(payload):
                    self.logger.warning("DNS question truncated.")
                    break
                qtype, qclass = struct.unpack("!HH", payload[offset:offset + 4])
                offset += 4
                self.logger.info("DNS Question %d: %s, type: %d, class: %d", i + 1, domain, qtype, qclass)
            except Exception as e:
                self.logger.exception("Error parsing DNS question: %s", e)
                break
class PayloadParser:
    """Payload pretty-printing plus magic-number (file signature) detection."""

    def __init__(self, logger: Optional[logging.Logger] = None) -> None:
        self.logger = logger or logging.getLogger(__name__)
        # Uppercase hex magic-number -> human-readable description.
        # Short keys (<= 8 hex chars, i.e. 4 bytes) are matched only at the
        # start of the payload; longer keys anywhere (see analyze_hex_dump).
        self.SIGNATURES = {
            "504B0304": "ZIP archive / Office Open XML document (DOCX, XLSX, PPTX)",
            "504B030414000600": "Office Open XML (DOCX, XLSX, PPTX) extended header",
            "1F8B08": "GZIP archive",
            "377ABCAF271C": "7-Zip archive",
            "52617221": "RAR archive",
            "425A68": "BZIP2 archive",
            "213C617263683E0A": "Ar (UNIX archive) / Debian package",
            "7F454C46": "ELF executable (Unix/Linux)",
            "4D5A": "Windows executable (EXE, MZ header / DLL)",
            "CAFEBABE": "Java class file or Mach-O Fat Binary (ambiguous)",
            "FEEDFACE": "Mach-O executable (32-bit, little-endian)",
            "CEFAEDFE": "Mach-O executable (32-bit, big-endian)",
            "FEEDFACF": "Mach-O executable (64-bit, little-endian)",
            "CFFAEDFE": "Mach-O executable (64-bit, big-endian)",
            "BEBAFECA": "Mach-O Fat Binary (little endian)",
            "4C000000": "Windows shortcut file (.lnk)",
            "4D534346": "Microsoft Cabinet file (CAB)",
            "D0CF11E0": "Microsoft Office legacy format (DOC, XLS, PPT)",
            "25504446": "PDF document",
            "7B5C727466": "RTF document (starting with '{\\rtf')",
            "3C3F786D6C": "XML file (<?xml)",
            "3C68746D6C3E": "HTML file",
            "252150532D41646F6265": "PostScript/EPS document (starts with '%!PS-Adobe')",
            "4D2D2D2D": "PostScript file (---)",
            "89504E47": "PNG image",
            "47494638": "GIF image",
            "FFD8FF": "JPEG image (general)",
            "FFD8FFE0": "JPEG image (JFIF)",
            "FFD8FFE1": "JPEG image (EXIF)",
            "424D": "Bitmap image (BMP)",
            "49492A00": "TIFF image (little endian / Intel)",
            "4D4D002A": "TIFF image (big endian / Motorola)",
            "38425053": "Adobe Photoshop document (PSD)",
            "00000100": "ICO icon file",
            "00000200": "CUR cursor file",
            "494433": "MP3 audio (ID3 tag)",
            "000001BA": "MPEG video (VCD)",
            "000001B3": "MPEG video",
            "66747970": "MP4/MOV file (ftyp)",
            "4D546864": "MIDI file",
            "464F524D": "AIFF audio file",
            "52494646": "AVI file (RIFF) [Also used in WAV]",
            "664C6143": "FLAC audio file",
            "4F676753": "OGG container file (OggS)",
            "53514C69": "SQLite database file (SQLite format 3)",
            "420D0D0A": "Python compiled file (.pyc) [example magic, may vary]",
            "6465780A": "Android Dalvik Executable (DEX) file",
            "EDABEEDB": "RPM package file",
            "786172210D0A1A0A": "XAR archive (macOS installer package)",
        }

    def normalize_payload(self, payload: bytes) -> str:
        """
        Attempt to decode the payload as UTF-8 text and replace non-printable
        characters with '.'.  If decoding fails, or more than half of the
        result is non-printable (likely binary), fall back to a hex string.
        """
        try:
            text = payload.decode('utf-8', errors='ignore')
            normalized = ''.join(ch if ch in string.printable else '.' for ch in text)
            if sum(1 for ch in normalized if ch == '.') > len(normalized) * 0.5:
                return payload.hex()
            return normalized
        except Exception:
            return payload.hex()

    def parse_http_payload(self, payload: bytes) -> None:
        """Log an HTTP payload as console-safe text."""
        try:
            text = payload.decode('utf-8', errors='replace')
            text = self._make_log_safe(text)
            self.logger.info("HTTP Payload:\n%s", text)
        except Exception as e:
            self.logger.exception("Error parsing HTTP payload: %s", e)

    def parse_text_payload(self, payload: bytes) -> None:
        """Log an arbitrary text payload as console-safe text."""
        try:
            text = payload.decode('utf-8', errors='replace')
            text = self._make_log_safe(text)
            self.logger.info("Text Payload:\n%s", text)
        except Exception as e:
            self.logger.exception("Error decoding text payload: %s", e)

    def parse_tls_payload(self, payload: bytes) -> None:
        """Log a TLS payload as a raw hex dump (no decryption attempted)."""
        self.logger.info("TLS Payload (hex):\n%s", payload.hex())

    def analyze_hex_dump(self, payload: bytes) -> List[Tuple[str, str]]:
        """Return (signature, description) pairs found in *payload*.

        The 23-byte head window (46 hex chars) comfortably covers the longest
        signature in the table; long signatures are additionally searched
        anywhere in the full payload.
        """
        head_hex = payload[:23].hex().upper()
        full_hex = payload.hex().upper()
        self.logger.info("Analyzing payload hex dump for signatures...")
        found = []
        for sig, desc in self.SIGNATURES.items():
            if len(sig) <= 8:
                matched = head_hex.startswith(sig)
            else:
                matched = head_hex.startswith(sig) or sig in full_hex
            if matched:
                found.append((sig, desc))
                self.logger.warning("Detected signature %s: %s", sig, desc)
        return found

    def _make_log_safe(self, text: str) -> str:
        """Re-encode to stdout's encoding with replacement so logging never raises."""
        target_encoding = getattr(sys.stdout, "encoding", None) or "utf-8"
        return text.encode(target_encoding, errors="replace").decode(target_encoding, errors="replace")
@dataclass(frozen=True)
class FlowKey:
    """Hashable 5-tuple identifying one bidirectional network flow."""
    src_ip: str
    src_port: int
    dst_ip: str
    dst_port: int
    proto: str


def canonical_flow_key(src_ip: str, src_port: int, dst_ip: str, dst_port: int, proto: str = "TCP") -> Tuple[FlowKey, bool]:
    """Return ``(key, is_forward)`` with the two endpoints in canonical order.

    Both directions of a connection map to the same FlowKey; the boolean
    reports whether ``(src_ip, src_port)`` is the canonical "left" endpoint.
    """
    endpoint_a = (src_ip, src_port)
    endpoint_b = (dst_ip, dst_port)
    if endpoint_a > endpoint_b:
        return FlowKey(dst_ip, dst_port, src_ip, src_port, proto), False
    return FlowKey(src_ip, src_port, dst_ip, dst_port, proto), True
@dataclass
class StreamDirectionState:
    """Reassembly state for one direction of a TCP stream.

    Bytes arriving in order are appended to ``contiguous``; segments that
    arrive ahead of ``next_seq`` are parked in ``out_of_order`` until the
    gap before them is filled.
    """
    base_seq: Optional[int] = None                # first sequence number seen
    next_seq: Optional[int] = None                # next expected sequence number
    contiguous: bytearray = field(default_factory=bytearray)
    out_of_order: Dict[int, bytes] = field(default_factory=dict)

    def add(self, seq: int, payload: bytes) -> int:
        """Insert one segment; return how many new contiguous bytes were appended."""
        if not payload:
            return 0
        if self.base_seq is None:
            self.base_seq = self.next_seq = seq
        elif self.next_seq is None:
            # Defensive: base known but cursor lost; resynchronise here.
            self.next_seq = seq
        # Trim any prefix we already hold (retransmission / partial overlap).
        if seq < self.next_seq:
            already_have = self.next_seq - seq
            if already_have >= len(payload):
                return 0
            payload = payload[already_have:]
            seq = self.next_seq
        if seq != self.next_seq:
            # Out of order: keep the longest chunk seen for this offset.
            parked = self.out_of_order.get(seq)
            if parked is None or len(payload) > len(parked):
                self.out_of_order[seq] = payload
            return 0
        grown = len(payload)
        self.contiguous.extend(payload)
        self.next_seq += len(payload)
        # Drain any parked segments that have become adjacent.
        while self.next_seq in self.out_of_order:
            chunk = self.out_of_order.pop(self.next_seq)
            self.contiguous.extend(chunk)
            grown += len(chunk)
            self.next_seq += len(chunk)
        return grown
@dataclass
class StreamState:
    """Both directions of one reassembled TCP conversation plus timestamps."""
    # Data flowing from the canonical "client" endpoint toward the "server".
    client_to_server: StreamDirectionState = field(default_factory=StreamDirectionState)
    # Data flowing back from the "server" toward the "client".
    server_to_client: StreamDirectionState = field(default_factory=StreamDirectionState)
    created_at: float = field(default_factory=time.time)  # wall-clock creation time
    last_seen: float = field(default_factory=time.time)   # refreshed on every segment
class TCPStreamReassembler:
    """Tracks per-flow TCP reassembly with TTL- and size-based eviction."""

    def __init__(self, logger: Optional[logging.Logger] = None, max_streams: int = 10000, stream_ttl: int = 1800) -> None:
        self.logger = logger or logging.getLogger(__name__)
        self.max_streams = max_streams    # hard cap on concurrently tracked flows
        self.stream_ttl = stream_ttl      # seconds of idleness before eviction
        self.streams: Dict[FlowKey, StreamState] = {}

    def _evict_old(self) -> None:
        """Drop idle flows, then trim the least-recently-seen above the cap."""
        now = time.time()
        stale = [key for key, st in self.streams.items() if now - st.last_seen > self.stream_ttl]
        for key in stale:
            self.streams.pop(key, None)
        overflow = len(self.streams) - self.max_streams
        if overflow > 0:
            by_age = sorted(self.streams.items(), key=lambda item: item[1].last_seen)
            for key, _ in by_age[:overflow]:
                self.streams.pop(key, None)

    def add_tcp_segment(self, src_ip: str, src_port: int, dst_ip: str, dst_port: int, seq: int, payload: bytes) -> Dict[str, Any]:
        """Feed one segment; return flow key, direction flag and both buffers."""
        key, forward_is_client = canonical_flow_key(src_ip, src_port, dst_ip, dst_port)
        stream = self.streams.setdefault(key, StreamState())
        stream.last_seen = time.time()
        direction = stream.client_to_server if forward_is_client else stream.server_to_client
        appended = direction.add(seq, payload)
        self._evict_old()
        return {
            "flow_key": key,
            "is_client_to_server": forward_is_client,
            "bytes_added": appended,
            "client_stream": bytes(stream.client_to_server.contiguous),
            "server_stream": bytes(stream.server_to_client.contiguous),
        }
class ProtocolDecoder:
    """Best-effort decoders for application-layer payloads.

    Every ``decode_*`` method returns a plain dict with a ``protocol`` key so
    downstream consumers can branch without isinstance checks.
    """

    def __init__(self, logger: Optional[logging.Logger] = None) -> None:
        self.logger = logger or logging.getLogger(__name__)

    def decode_http(self, payload: bytes, direction: str) -> Dict[str, Any]:
        """Split an HTTP/1.x message into start line, headers and body.

        *direction* is carried through unchanged.  iso-8859-1 is used because
        it maps every byte 1:1, so decode/encode round-trips the raw body.
        """
        text = payload.decode("iso-8859-1", errors="replace")
        head, _, body = text.partition("\r\n\r\n")
        lines = head.split("\r\n") if head else []
        headers: Dict[str, str] = {}
        for line in lines[1:]:
            if ":" in line:
                k, v = line.split(":", 1)
                headers[k.strip().lower()] = v.strip()
        normalized = {
            "protocol": "http1",
            "direction": direction,
            "start_line": lines[0] if lines else "",
            "headers": headers,
            "body_bytes": body.encode("iso-8859-1", errors="ignore"),
            "host": headers.get("host", ""),
            "content_type": headers.get("content-type", ""),
            "transfer_encoding": headers.get("transfer-encoding", "").lower(),
            "content_encoding": headers.get("content-encoding", "").lower(),
        }
        return normalized

    def decode_dns(self, payload: bytes) -> Dict[str, Any]:
        """Decode the fixed DNS header plus the first question name."""
        if len(payload) < 12:
            return {"protocol": "dns", "error": "short_payload"}
        tid, flags, qdcount, ancount, nscount, arcount = struct.unpack("!HHHHHH", payload[:12])
        qr = (flags >> 15) & 0x1        # query(0) / response(1) bit
        opcode = (flags >> 11) & 0xF
        rcode = flags & 0xF
        query_name = ""
        if qdcount > 0:
            try:
                query_name, _ = DNSParser(logger=self.logger).parse_dns_name(payload, 12)
            except Exception:
                query_name = ""
        return {
            "protocol": "dns",
            "transaction_id": tid,
            "is_response": bool(qr),
            "opcode": opcode,
            "rcode": rcode,
            "qdcount": qdcount,
            "ancount": ancount,
            "nscount": nscount,
            "arcount": arcount,
            "query": query_name,
        }

    def _parse_tls_client_hello(self, payload: bytes) -> Dict[str, Any]:
        """Extract SNI and a JA3 fingerprint from a TLS ClientHello record.

        Returns {} whenever the bytes do not parse as a ClientHello.
        NOTE(review): the JA3 version field is hard-coded to "771" (TLS 1.2)
        rather than read from the hello, so hashes can disagree with other
        JA3 implementations for non-1.2 hellos -- confirm before comparing.
        """
        # TLS record header: content type 0x16 = handshake.
        if len(payload) < 5 or payload[0] != 0x16:
            return {}
        rec_len = int.from_bytes(payload[3:5], "big")
        rec = payload[5:5 + rec_len]
        # Handshake header: type 0x01 = ClientHello.
        if len(rec) < 4 or rec[0] != 0x01:
            return {}
        hs_len = int.from_bytes(rec[1:4], "big")
        body = rec[4:4 + hs_len]
        # 2 bytes client_version + 32 bytes random precede the session id.
        if len(body) < 34:
            return {}
        idx = 34
        if idx >= len(body):
            return {}
        sess_len = body[idx]
        idx += 1 + sess_len
        if idx + 2 > len(body):
            return {}
        ciphers_len = int.from_bytes(body[idx:idx + 2], "big")
        idx += 2
        ciphers_raw = body[idx:idx + ciphers_len]
        idx += ciphers_len
        if idx >= len(body):
            return {}
        comp_len = body[idx]
        idx += 1 + comp_len
        exts: List[int] = []     # extension type codes, in wire order
        curves: List[int] = []   # supported_groups (extension 10)
        ec_pf: List[int] = []    # ec_point_formats (extension 11)
        sni = ""
        if idx + 2 <= len(body):
            ext_total = int.from_bytes(body[idx:idx + 2], "big")
            idx += 2
            end = min(len(body), idx + ext_total)
            while idx + 4 <= end:
                etype = int.from_bytes(body[idx:idx + 2], "big")
                elen = int.from_bytes(body[idx + 2:idx + 4], "big")
                idx += 4
                eval_data = body[idx:idx + elen]
                idx += elen
                exts.append(etype)
                if etype == 0 and len(eval_data) >= 5:
                    # server_name extension: list length, name type 0 = host_name.
                    list_len = int.from_bytes(eval_data[0:2], "big")
                    if 2 + list_len <= len(eval_data) and eval_data[2] == 0:
                        name_len = int.from_bytes(eval_data[3:5], "big")
                        if 5 + name_len <= len(eval_data):
                            sni = eval_data[5:5 + name_len].decode("utf-8", errors="replace")
                elif etype == 10 and len(eval_data) >= 2:
                    glen = int.from_bytes(eval_data[0:2], "big")
                    data = eval_data[2:2 + glen]
                    curves = [int.from_bytes(data[i:i + 2], "big") for i in range(0, len(data), 2) if i + 2 <= len(data)]
                elif etype == 11 and len(eval_data) >= 1:
                    flen = eval_data[0]
                    data = eval_data[1:1 + flen]
                    ec_pf = [b for b in data]
        ciphers = [int.from_bytes(ciphers_raw[i:i + 2], "big") for i in range(0, len(ciphers_raw), 2) if i + 2 <= len(ciphers_raw)]
        # JA3 = version,ciphers,extensions,curves,point-formats joined by commas.
        ja3_string = ",".join([
            "771",
            "-".join(str(c) for c in ciphers),
            "-".join(str(e) for e in exts),
            "-".join(str(c) for c in curves),
            "-".join(str(p) for p in ec_pf),
        ])
        ja3_hash = hashlib.md5(ja3_string.encode("utf-8", errors="ignore")).hexdigest()
        return {
            "protocol": "tls",
            "record_type": "handshake",
            "sni": sni,
            "ja3": ja3_string,
            "ja3_hash": ja3_hash,
        }

    def decode_smtp(self, payload: bytes) -> Dict[str, Any]:
        """Summarise an SMTP exchange as the leading verbs of its first 25 lines."""
        text = payload.decode("utf-8", errors="replace")
        lines = [line.strip() for line in text.splitlines() if line.strip()]
        commands = [line.split(" ", 1)[0].upper() for line in lines[:25]]
        return {"protocol": "smtp", "commands": commands, "line_count": len(lines)}

    def decode_smb(self, payload: bytes) -> Dict[str, Any]:
        """Identify SMB1 (\\xffSMB) / SMB2+ (\\xfeSMB) headers and the command byte."""
        if len(payload) >= 4 and payload[:4] in (b"\xfeSMB", b"\xffSMB"):
            return {"protocol": "smb", "signature": payload[:4].hex(), "command": payload[4] if len(payload) > 4 else None}
        return {"protocol": "smb", "error": "not_smb"}

    def decode_by_ports_or_signature(self, src_port: int, dst_port: int, payload: bytes, direction: str) -> Dict[str, Any]:
        """Dispatch to a protocol decoder by well-known port or payload prefix."""
        ports = {src_port, dst_port}
        if not payload:
            return {"protocol": "unknown"}
        if 53 in ports:
            return self.decode_dns(payload)
        if payload.startswith((b"GET ", b"POST ", b"PUT ", b"DELETE ", b"HEAD ", b"OPTIONS ", b"HTTP/1.")) or 80 in ports or 8080 in ports:
            return self.decode_http(payload, direction)
        if payload[:1] == b"\x16" or 443 in ports:
            tls_data = self._parse_tls_client_hello(payload)
            return tls_data if tls_data else {"protocol": "tls", "record_type": "unknown"}
        if 25 in ports or payload[:5].upper() in (b"HELO ", b"EHLO ", b"MAIL ", b"RCPT ", b"DATA\r"):
            return self.decode_smtp(payload)
        if 445 in ports or payload.startswith((b"\xfeSMB", b"\xffSMB")):
            return self.decode_smb(payload)
        return {"protocol": "unknown"}
class ContentDecoder:
    """Undoes HTTP transfer/content encodings, sniffs containers, tries base64."""

    def __init__(self, logger: Optional[logging.Logger] = None) -> None:
        self.logger = logger or logging.getLogger(__name__)

    def decode_chunked(self, body: bytes) -> bytes:
        """Reassemble an HTTP/1.1 chunked transfer-encoded body."""
        assembled = bytearray()
        cursor = 0
        while cursor < len(body):
            line_end = body.find(b"\r\n", cursor)
            if line_end < 0:
                break
            # Chunk-size line may carry ";ext" parameters -- strip them.
            size_token = body[cursor:line_end].split(b";", 1)[0].strip()
            try:
                chunk_size = int(size_token, 16)
            except ValueError:
                break
            cursor = line_end + 2
            if chunk_size == 0:
                break  # terminating zero-length chunk
            assembled += body[cursor:cursor + chunk_size]
            cursor += chunk_size + 2  # skip chunk data plus trailing CRLF
        return bytes(assembled)

    def decode_http_content(self, decoded_http: Dict[str, Any]) -> bytes:
        """Undo transfer-encoding then content-encoding on a decoded HTTP dict."""
        body = decoded_http.get("body_bytes", b"")
        if decoded_http.get("transfer_encoding") == "chunked":
            body = self.decode_chunked(body)
        encoding = decoded_http.get("content_encoding", "")
        try:
            if "gzip" in encoding:
                body = gzip.decompress(body)
            elif "deflate" in encoding:
                body = zlib.decompress(body)
        except Exception as exc:
            # Truncated captures routinely fail to decompress; keep raw bytes.
            self.logger.debug("Content decoding failed: %s", exc)
        return body

    def try_base64(self, payload: bytes) -> bytes:
        """Decode *payload* as base64 if it plausibly is; b'' otherwise."""
        compact = b"".join(payload.split())
        if len(compact) < 16 or not re.fullmatch(rb"[A-Za-z0-9+/=]+", compact):
            return b""
        try:
            return binascii.a2b_base64(compact)
        except Exception:
            return b""

    def detect_container(self, payload: bytes) -> str:
        """Classify a blob by magic number: zip / pdf / pe / ole / unknown."""
        magic_table = (
            (b"PK\x03\x04", "zip"),
            (b"%PDF-", "pdf"),
            (b"MZ", "pe"),
            (bytes.fromhex("D0CF11E0A1B11AE1"), "ole"),
        )
        for magic, kind in magic_table:
            if payload.startswith(magic):
                return kind
        return "unknown"
class FileObjectCarver:
    """Writes recognised file objects carved from reassembled streams to disk."""

    def __init__(self, logger: Optional[logging.Logger] = None, output_dir: str = "carved_objects") -> None:
        self.logger = logger or logging.getLogger(__name__)
        self.output_dir = output_dir
        os.makedirs(self.output_dir, exist_ok=True)

    def carve(self, flow_key: FlowKey, payload: bytes, source: str) -> Optional[Dict[str, Any]]:
        """Persist *payload* when it starts with a known container magic.

        Returns metadata (path, sha256, container, size) or None when the
        payload is unrecognised or shorter than 64 bytes.
        """
        container = ContentDecoder(logger=self.logger).detect_container(payload)
        if container == "unknown" or len(payload) < 64:
            return None
        # datetime.utcnow() is deprecated (and naive); use an aware UTC time.
        # strftime output format is identical, so filenames are unchanged.
        timestamp = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%S%f")
        filename = f"{flow_key.src_ip}_{flow_key.src_port}_{flow_key.dst_ip}_{flow_key.dst_port}_{source}_{timestamp}.{container}"
        # Sanitise so IPv6 colons etc. cannot escape the output directory.
        safe_name = re.sub(r"[^A-Za-z0-9_.-]", "_", filename)
        path = os.path.join(self.output_dir, safe_name)
        with open(path, "wb") as fh:
            fh.write(payload)
        return {
            "path": path,
            "sha256": hashlib.sha256(payload).hexdigest(),
            "container": container,
            "size": len(payload),
        }
class SignatureScanner:
    """Optional YARA plus on-host antivirus scanning of payloads and files."""

    def __init__(self, logger: Optional[logging.Logger] = None, yara_rules_path: str = "") -> None:
        self.logger = logger or logging.getLogger(__name__)
        self.yara_rules_path = yara_rules_path
        self.yara_rules = None
        rules_usable = yara is not None and yara_rules_path and os.path.exists(yara_rules_path)
        if rules_usable:
            try:
                self.yara_rules = yara.compile(filepath=yara_rules_path)
                self.logger.info("YARA rules loaded: %s", yara_rules_path)
            except Exception as exc:
                self.logger.warning("Failed to load YARA rules: %s", exc)

    def scan_bytes(self, payload: bytes) -> List[str]:
        """Return names of YARA rules matching *payload*; [] without rules."""
        if self.yara_rules is None:
            return []
        try:
            return [match.rule for match in self.yara_rules.match(data=payload)]
        except Exception as exc:
            self.logger.debug("YARA byte scan failed: %s", exc)
            return []

    def scan_file(self, path: str) -> Dict[str, Any]:
        """Scan a carved file with YARA and, when available, clamscan or Defender."""
        result: Dict[str, Any] = {"yara": [], "av": "not_run"}
        if self.yara_rules is not None:
            try:
                result["yara"] = [match.rule for match in self.yara_rules.match(path)]
            except Exception as exc:
                self.logger.debug("YARA file scan failed for %s: %s", path, exc)
        command = None
        if shutil.which("clamscan"):
            command = ["clamscan", "--no-summary", path]
        elif shutil.which("MpCmdRun.exe"):
            # Defender CLI: -ScanType 3 = custom scan of a single file.
            command = ["MpCmdRun.exe", "-Scan", "-ScanType", "3", "-File", path]
        if command:
            try:
                proc = subprocess.run(command, capture_output=True, text=True, timeout=30)
                combined = (proc.stdout or "") + "\n" + (proc.stderr or "")
                result["av"] = combined.strip()[:1000]
            except Exception as exc:
                result["av"] = f"scan_failed: {exc}"
        return result
class HeuristicDetector:
    """Cheap statistical heuristics: entropy, DGA-ish domains, scan/DNS rates."""

    def __init__(self, logger: Optional[logging.Logger] = None) -> None:
        self.logger = logger or logging.getLogger(__name__)
        # Per-source rolling histories of (timestamp, target) tuples.
        self.conn_history: Dict[str, deque] = defaultdict(lambda: deque(maxlen=128))
        self.dns_history: Dict[str, deque] = defaultdict(lambda: deque(maxlen=128))

    @staticmethod
    def entropy(payload: bytes) -> float:
        """Shannon entropy in bits/byte; 0.0 for empty input or without numpy."""
        if not payload:
            return 0.0
        if np is None:
            return 0.0
        total = len(payload)
        probabilities = [count / total for count in Counter(payload).values()]
        return -sum(p * np.log2(p) for p in probabilities)

    def score_dga_like(self, domain: str) -> float:
        """Score in [0, 1] of how DGA-like the leading label of *domain* looks."""
        if not domain:
            return 0.0
        label = domain.split(".")[0].lower()
        if len(label) < 12:
            # Short labels are too noisy to score meaningfully.
            return 0.0
        size = max(len(label), 1)
        consonants = sum(1 for ch in label if ch in "bcdfghjklmnpqrstvwxyz")
        vowels = sum(1 for ch in label if ch in "aeiou")
        digit_ratio = sum(1 for ch in label if ch.isdigit()) / size
        ordinary = string.ascii_lowercase + string.digits + "-"
        weird_ratio = sum(1 for ch in label if ch not in ordinary) / size
        imbalance = abs(consonants - vowels) / size
        return min(1.0, digit_ratio * 0.6 + weird_ratio * 0.8 + imbalance)

    def track_connect(self, src_ip: str, dst_ip: str) -> float:
        """Record a connection attempt; score fan-out of distinct targets/minute."""
        now = time.time()
        window = self.conn_history[src_ip]
        window.append((now, dst_ip))
        distinct_targets = {target for ts, target in window if now - ts <= 60}
        return min(1.0, len(distinct_targets) / 40.0)

    def track_dns(self, src_ip: str, domain: str) -> float:
        """Record a DNS query; score the per-source query rate over a minute."""
        now = time.time()
        window = self.dns_history[src_ip]
        window.append((now, domain))
        recent_count = sum(1 for ts, _ in window if now - ts <= 60)
        return min(1.0, recent_count / 80.0)
class SessionCorrelator:
    """Keeps a short per-host event timeline and spots a simple kill chain."""

    def __init__(self, logger: Optional[logging.Logger] = None) -> None:
        self.logger = logger or logging.getLogger(__name__)
        # host -> bounded deque of {"ts", "kind", "details"} dicts.
        self.events: Dict[str, deque] = defaultdict(lambda: deque(maxlen=256))

    def add_event(self, host: str, kind: str, details: Dict[str, Any]) -> None:
        """Append a timestamped event of *kind* to the host's timeline."""
        self.events[host].append({"ts": time.time(), "kind": kind, "details": details})

    def detect_kill_chain(self, host: str) -> bool:
        """True when dns_query -> connect -> download -> execute occur in order.

        Only the last 20 events are inspected, and each stage's position is
        its FIRST occurrence within that window.
        """
        timeline = list(self.events.get(host, []))
        if len(timeline) < 3:
            return False
        recent_kinds = [event["kind"] for event in timeline[-20:]]
        positions = []
        for stage in ("dns_query", "connect", "download", "execute"):
            try:
                positions.append(recent_kinds.index(stage))
            except ValueError:
                return False
        return positions[0] < positions[1] < positions[2] < positions[3]
class MLClassifier:
    """Wraps an optional scikit-learn model over two payload features.

    Features are (byte_length, shannon_entropy).  When no model can be
    loaded, a conservative size-plus-entropy heuristic stands in.
    """

    def __init__(self, model_path: str = "model.pkl", logger: Optional[logging.Logger] = None) -> None:
        self.logger = logger or logging.getLogger(__name__)
        try:
            import joblib
            self.model = joblib.load(model_path)
            self.logger.info("ML model loaded successfully from %s", model_path)
        except Exception as e:
            # Missing joblib or missing/corrupt model file: run heuristically.
            self.logger.error("Failed to load ML model from %s: %s", model_path, e)
            self.model = None

    def extract_features(self, payload: bytes) -> np.ndarray:
        """Return a (1, 2) array of [byte_length, shannon_entropy_bits]."""
        size = len(payload)
        denom = size if size > 0 else 1
        frequencies = Counter(payload)
        ent = -sum((n / denom) * np.log2(n / denom) for n in frequencies.values() if n > 0)
        return np.array([[size, ent]])

    def classify(self, payload: bytes) -> bool:
        """Return True when the payload is judged malicious."""
        feats = self.extract_features(payload)
        if self.model is None:
            # Fallback heuristic: large AND high-entropy payloads are suspicious.
            suspicious = feats[0, 0] > 1000 and feats[0, 1] > 7.0
            if suspicious:
                self.logger.debug("Fallback heuristic: payload marked as malicious.")
                return True
            return False
        prediction = self.model.predict(feats)
        self.logger.debug("ML model prediction: %s", prediction)
        return bool(prediction[0])

    @staticmethod
    def train_model(dataset_path: str, model_output_path: str, logger: Optional[logging.Logger] = None) -> None:
        """Train a RandomForest on a CSV of hex payload/label rows and save it."""
        import pandas as pd
        from sklearn.ensemble import RandomForestClassifier
        import joblib
        logger = logger or logging.getLogger(__name__)
        logger.info("Loading dataset from %s", dataset_path)
        try:
            df = pd.read_csv(dataset_path)
        except Exception as e:
            logger.error("Failed to load dataset: %s", e)
            return
        extractor = MLClassifier(logger=logger)
        feature_rows: list = []
        label_rows: list = []
        for index, row in df.iterrows():
            try:
                raw = bytes.fromhex(row.get('payload', ''))
            except Exception as e:
                logger.error("Error converting payload to bytes for row %d: %s", index, e)
                continue
            feature_rows.append(extractor.extract_features(raw)[0])
            label_rows.append(row.get('label', 0))
        if not feature_rows:
            logger.error("No valid training samples found.")
            return
        logger.info("Training model on %d samples", len(feature_rows))
        forest = RandomForestClassifier(n_estimators=100, random_state=42)
        forest.fit(feature_rows, label_rows)
        joblib.dump(forest, model_output_path)
        logger.info("Model trained and saved to %s", model_output_path)
def get_local_process_info(port: int) -> str:
    """Best-effort lookup of the local PID bound to *port*.

    Scans psutil's inet connection table for a matching local port.
    Returns the PID as a string, "N/A (psutil not installed)" when psutil
    is missing, or "N/A" when no match is found or the scan fails.
    """
    try:
        import psutil
        for conn in psutil.net_connections(kind="inet"):
            if conn.laddr and conn.laddr.port == port:
                return str(conn.pid) if conn.pid is not None else "N/A"
    except ImportError:
        return "N/A (psutil not installed)"
    except Exception:
        # psutil can raise AccessDenied/PermissionError when run without
        # sufficient privileges; degrade gracefully instead of letting the
        # exception crash the packet handler that calls this for logging.
        return "N/A"
    return "N/A"
class XDPCollector:
    """Attaches a minimal XDP/eBPF program (via BCC) to an interface and
    publishes per-packet IPv4 metadata through a perf buffer.

    Events are copied into the module-global ``last_xdp_metadata`` so
    other components (e.g. the LLM prompt builder) can read the most
    recent one.
    """
    def __init__(self, interface: str, logger: Optional[logging.Logger] = None) -> None:
        """Compile and attach the XDP program on *interface*.

        Raises:
            ImportError: when the BCC bindings are unavailable.
        """
        if BPF is None:
            raise ImportError("BCC BPF module is not available. Please install bcc.")
        self.logger = logger or logging.getLogger(__name__)
        self.interface = interface
        self.bpf = BPF(text=self._bpf_program())
        func = self.bpf.load_func("xdp_prog", BPF.XDP)
        self.bpf.attach_xdp(self.interface, func, 0)
        self.logger.info("XDP program attached on interface %s", self.interface)
        # Route perf-buffer events to _handle_event; drained by poll().
        self.bpf["events"].open_perf_buffer(self._handle_event)
    def _bpf_program(self) -> str:
        """Return the C source of the XDP program: for every IPv4 frame it
        emits frame length, source/destination address and IP protocol,
        then always passes the packet through (XDP_PASS)."""
        return """
        #include <uapi/linux/bpf.h>
        #include <linux/if_ether.h>
        #include <linux/ip.h>
        struct data_t {
            u32 pkt_len;
            u32 src_ip;
            u32 dst_ip;
            u8 protocol;
        };
        BPF_PERF_OUTPUT(events);
        int xdp_prog(struct xdp_md *ctx) {
            struct data_t data = {};
            void *data_end = (void *)(long)ctx->data_end;
            void *data_ptr = (void *)(long)ctx->data;
            struct ethhdr *eth = data_ptr;
            if (data_ptr + sizeof(*eth) > data_end)
                return XDP_PASS;
            if (eth->h_proto == __constant_htons(ETH_P_IP)) {
                struct iphdr *ip = data_ptr + sizeof(*eth);
                if ((void*)ip + sizeof(*ip) > data_end)
                    return XDP_PASS;
                data.pkt_len = data_end - data_ptr;
                data.src_ip = ip->saddr;
                data.dst_ip = ip->daddr;
                data.protocol = ip->protocol;
                events.perf_submit(ctx, &data, sizeof(data));
            }
            return XDP_PASS;
        }
        """
    def _handle_event(self, cpu, data, size):
        """Perf-buffer callback: stash the newest event in last_xdp_metadata."""
        global last_xdp_metadata
        event = self.bpf["events"].event(data)
        last_xdp_metadata = {
            "pkt_len": event.pkt_len,
            # NOTE(review): src_ip/dst_ip are raw network-byte-order u32
            # values, not dotted-quad strings — confirm consumers expect ints.
            "src_ip": event.src_ip,
            "dst_ip": event.dst_ip,
            "protocol": event.protocol
        }
        self.logger.debug("XDP event: %s", last_xdp_metadata)
    def poll(self):
        """Drain pending perf-buffer events (waits at most 100 ms)."""
        self.bpf.perf_buffer_poll(timeout=100)
    def detach(self):
        """Detach the XDP program from the interface."""
        self.bpf.remove_xdp(self.interface, 0)
def get_xdp_metadata_details() -> str:
    """Render the most recent XDP event (module global) as a one-line
    summary, or a placeholder string when none has been captured yet."""
    global last_xdp_metadata
    if not last_xdp_metadata:
        return "No XDP metadata available."
    meta = last_xdp_metadata
    return (
        f"XDP Metadata: Packet Length: {meta.get('pkt_len')}, "
        f"Src IP: {meta.get('src_ip')}, "
        f"Dst IP: {meta.get('dst_ip')}, "
        f"Protocol: {meta.get('protocol')}"
    )
class PacketSniffer:
    """Scapy-based capture and analysis engine.

    Wires together IP lookup, protocol decoding, file carving,
    signature/heuristic/ML detection and session correlation, and emits
    alerts (and an optional "block" action) for high-risk flows.

    Behaviour is driven by the parsed CLI args: --sentry enables ML
    classification plus blocking, --expensive routes every Raw payload
    to the LLM labeler, --xdp attaches the eBPF metadata collector.
    """
    # Well-known port -> (application name, description); consulted by
    # identify_application for the destination port first, then the source.
    COMMON_PORTS = {
        20: ("FTP Data", "File Transfer Protocol - Data channel"),
        21: ("FTP Control", "File Transfer Protocol - Control channel"),
        22: ("SSH", "Secure Shell"),
        23: ("Telnet", "Telnet protocol"),
        25: ("SMTP", "Simple Mail Transfer Protocol"),
        53: ("DNS", "Domain Name System"),
        67: ("DHCP", "DHCP Server"),
        68: ("DHCP", "DHCP Client"),
        80: ("HTTP", "Hypertext Transfer Protocol"),
        110: ("POP3", "Post Office Protocol"),
        119: ("NNTP", "Network News Transfer Protocol"),
        123: ("NTP", "Network Time Protocol"),
        143: ("IMAP", "Internet Message Access Protocol"),
        161: ("SNMP", "Simple Network Management Protocol"),
        443: ("HTTPS", "HTTP Secure"),
        3306: ("MySQL", "MySQL database service"),
        5432: ("PostgreSQL", "PostgreSQL database service"),
        3389: ("RDP", "Remote Desktop Protocol")
    }
    def __init__(self, args: argparse.Namespace) -> None:
        """Wire up all analysis helpers from the parsed CLI arguments."""
        self.args = args
        self.logger = logging.getLogger("PacketSniffer")
        self.flagged_logger = logging.getLogger("FlaggedIPLogger")
        self.dns_logger = logging.getLogger("DNSQueryLogger")
        self.ip_lookup = IPLookup(ttl=3600, logger=self.logger)
        self.dns_parser = DNSParser(logger=self.logger)
        self.payload_parser = PayloadParser(logger=self.logger)
        self.protocol_decoder = ProtocolDecoder(logger=self.logger)
        self.content_decoder = ContentDecoder(logger=self.logger)
        self.file_carver = FileObjectCarver(logger=self.logger, output_dir=getattr(self.args, "carve_dir", "carved_objects"))
        self.signature_scanner = SignatureScanner(logger=self.logger, yara_rules_path=getattr(self.args, "yara_rules", ""))
        self.heuristics = HeuristicDetector(logger=self.logger)
        self.correlator = SessionCorrelator(logger=self.logger)
        self.reassembler = TCPStreamReassembler(logger=self.logger)
        self.async_runner = AsyncRunner()
        # ml_classifier only exists in sentry mode; every use below is
        # guarded by a self.args.sentry check.
        if self.args.sentry:
            self.ml_classifier = MLClassifier(model_path=self.args.model_path, logger=self.logger)
        self.xdp_collector = None
        if self.args.xdp and BPF is not None:
            try:
                self.xdp_collector = XDPCollector(self.args.interface, logger=self.logger)
                # Daemon thread so a stuck poll loop never blocks shutdown.
                self.xdp_thread = threading.Thread(target=self._poll_xdp, daemon=True)
                self.xdp_thread.start()
            except Exception as e:
                self.logger.exception("Failed to initialize XDPCollector: %s", e)
    def _risk_from_signals(self, signature_hits: List[str], heuristic_score: float, ml_flag: bool) -> float:
        """Fuse detection signals into a capped [0, 1] risk score:
        signatures contribute up to 0.6 (0.2 each), heuristics up to 0.3,
        and a positive ML verdict adds a flat 0.3."""
        score = 0.0
        score += min(0.6, 0.2 * len(signature_hits))
        score += min(0.3, heuristic_score * 0.3)
        if ml_flag:
            score += 0.3
        return min(1.0, score)
    def _emit_session_events(self, src_ip: str, dst_ip: str, decoded: Dict[str, Any], carved: Optional[Dict[str, Any]]) -> None:
        """Feed per-host events (connect / dns_query / download / execute)
        into the SessionCorrelator for kill-chain detection."""
        protocol = decoded.get("protocol", "unknown")
        self.correlator.add_event(src_ip, "connect", {"dst_ip": dst_ip, "protocol": protocol})
        if protocol == "dns" and decoded.get("query"):
            self.correlator.add_event(src_ip, "dns_query", {"query": decoded.get("query")})
        if carved is not None:
            self.correlator.add_event(src_ip, "download", {"sha256": carved.get("sha256"), "path": carved.get("path")})
            # A carved PE container is assumed to imply an execution stage.
            if carved.get("container") in ("pe",):
                self.correlator.add_event(src_ip, "execute", {"container": carved.get("container")})
    def _analyze_reassembled_tcp(self, src_ip: str, src_port: int, dst_ip: str, dst_port: int, tcp_seq: int, payload: bytes) -> None:
        """Run the deep-inspection pipeline on a reassembled TCP stream:
        protocol decode, content decode, signature scan, file carving,
        heuristics, optional ML, then risk scoring and alerting."""
        if not payload:
            return
        reassembled = self.reassembler.add_tcp_segment(src_ip, src_port, dst_ip, dst_port, tcp_seq, payload)
        if reassembled.get("bytes_added", 0) == 0:
            # Duplicate/overlapping segment contributed no new stream bytes.
            return
        flow_key: FlowKey = reassembled["flow_key"]
        direction = "c2s" if reassembled.get("is_client_to_server") else "s2c"
        stream_payload = reassembled["client_stream"] if direction == "c2s" else reassembled["server_stream"]
        # Inspect only the newest 64 KiB of the stream to bound per-packet work.
        tail = stream_payload[-65536:]
        decoded = self.protocol_decoder.decode_by_ports_or_signature(src_port, dst_port, tail, direction)
        protocol = decoded.get("protocol", "unknown")
        signature_hits: List[str] = []
        candidate_payload = tail
        if protocol == "http1":
            candidate_payload = self.content_decoder.decode_http_content(decoded)
        else:
            # If the tail looks base64-encoded, scan the decoded form instead.
            b64 = self.content_decoder.try_base64(tail)
            if b64:
                candidate_payload = b64
        signature_hits.extend(self.signature_scanner.scan_bytes(candidate_payload))
        carved = self.file_carver.carve(flow_key, candidate_payload, source=direction)
        file_scan: Dict[str, Any] = {"yara": [], "av": "not_run"}
        if carved is not None:
            file_scan = self.signature_scanner.scan_file(carved["path"])
            signature_hits.extend(file_scan.get("yara", []))
        # Entropy >= 7.3 over the first 4 KiB is treated as packed/encrypted.
        entropy_score = 1.0 if self.heuristics.entropy(candidate_payload[:4096]) >= 7.3 else 0.0
        dga_score = 0.0
        if protocol == "dns":
            dga_score = self.heuristics.score_dga_like(decoded.get("query", ""))
        beacon_score = self.heuristics.track_connect(src_ip, dst_ip)
        if protocol == "dns":
            # Called for its rate-tracking side effect; score is unused here.
            _ = self.heuristics.track_dns(src_ip, decoded.get("query", ""))
        heuristic_score = min(1.0, max(entropy_score, dga_score, beacon_score))
        ml_flag = False
        if self.args.sentry:
            ml_flag = self.ml_classifier.classify(candidate_payload)
        risk = self._risk_from_signals(signature_hits, heuristic_score, ml_flag)
        self._emit_session_events(src_ip, dst_ip, decoded, carved)
        kill_chain = self.correlator.detect_kill_chain(src_ip)
        tls_meta = {}
        if protocol == "tls":
            tls_meta = {
                "sni": decoded.get("sni", ""),
                "ja3": decoded.get("ja3", ""),
                "ja3_hash": decoded.get("ja3_hash", ""),
            }
        # Alert threshold: high fused risk, or a completed kill chain.
        if risk >= 0.7 or kill_chain:
            message = (
                "\n====== ADVANCED DETECTION ALERT ======\n"
                f"Flow: {src_ip}:{src_port} -> {dst_ip}:{dst_port}\n"
                f"Protocol: {protocol}\n"
                f"Risk Score: {risk:.2f}\n"
                f"Signatures: {signature_hits}\n"
                f"Heuristic Score: {heuristic_score:.2f}\n"
                f"ML Flag: {ml_flag}\n"
                f"TLS Metadata: {tls_meta}\n"
                f"Carved Object: {carved}\n"
                f"File Scan: {file_scan}\n"
                f"Kill Chain Detected: {kill_chain}\n"
                "======================================\n"
            )
            self.flagged_logger.warning(message)
            if self.args.sentry:
                self.block_ip(src_ip)
    def _poll_xdp(self):
        """Background loop draining XDP perf-buffer events.

        NOTE(review): bare `while True` with no sleep/backoff — if poll()
        raises persistently this busy-spins while logging each failure;
        consider a short sleep on error. Runs in a daemon thread.
        """
        while True:
            try:
                if self.xdp_collector:
                    self.xdp_collector.poll()
            except Exception as e:
                self.logger.exception("Error polling XDP events: %s", e)
    def identify_application(self, src_port: int, dst_port: int) -> Tuple[str, str]:
        """Map the destination (then source) port onto a known application
        (name, description) pair; defaults to Unknown."""
        for port in (dst_port, src_port):
            if port in self.COMMON_PORTS:
                return self.COMMON_PORTS[port]
        return ("Unknown", "Unknown application protocol")
    def block_ip(self, ip: str) -> None:
        """Record a "block" action for *ip*.

        NOTE(review): this only appends a line to BLOCKER.GARY via a shell
        command — no firewall rule is installed — and the IP string is
        interpolated into a shell command (injection surface if it ever
        came from an unparsed source).
        """
        cmd = f"echo 'Blocking IP {ip}' >> BLOCKER.GARY"
        self.logger.info("Blocking IP %s with command: %s", ip, cmd)
        os.system(cmd)
    def _is_noisy_discovery_traffic(self, packet) -> bool:
        """True for private-source UDP sent to multicast/broadcast — the
        mDNS/SSDP-style chatter that would otherwise spam flagged-IP alerts."""
        if not packet.haslayer(IP):
            return False
        src_ip = packet[IP].src
        dst_ip = packet[IP].dst
        try:
            src_obj = ipaddress.ip_address(src_ip)
            dst_obj = ipaddress.ip_address(dst_ip)
        except ValueError:
            return False
        if not packet.haslayer(UDP):
            return False
        return src_obj.is_private and (dst_obj.is_multicast or dst_ip == "255.255.255.255")
    def log_flagged_ip(self, packet, flagged_signatures: List[Tuple[str, str]],
                       app_name: str, app_details: str) -> None:
        """Emit a FLAGGED IP alert enriched with ports, best-effort local
        PID, and cached geo/whois background for the source address."""
        source_ip = "Unknown"
        dest_ip = "Unknown"
        port_info = ""
        process_id = "N/A"
        if packet.haslayer(IP):
            source_ip = packet[IP].src
            dest_ip = packet[IP].dst
        elif packet.haslayer(IPv6):
            source_ip = packet[IPv6].src
            dest_ip = packet[IPv6].dst
        if packet.haslayer(TCP):
            tcp_layer = packet[TCP]
            port_info = f"TCP src: {tcp_layer.sport}, dst: {tcp_layer.dport}"
            process_id = get_local_process_info(tcp_layer.dport)
        elif packet.haslayer(UDP):
            udp_layer = packet[UDP]
            port_info = f"UDP src: {udp_layer.sport}, dst: {udp_layer.dport}"
            process_id = get_local_process_info(udp_layer.dport)
        ip_background = "No background info available."
        # The lookup runs on the AsyncRunner's loop; wait at most 6 seconds.
        future = self.async_runner.run_coroutine(self.ip_lookup.get_ip_info(source_ip))
        try:
            info = future.result(timeout=6)
            if info:
                ip_background = "\n".join([f"{k.capitalize()}: {v}" for k, v in info.items() if k in ("hostname", "city", "region", "country", "org")])
        except Exception as e:
            self.logger.exception("Error retrieving IP background info: %s", e)
        message = (
            "\n====== FLAGGED IP ALERT ======\n"
            f"Source IP: {source_ip}\n"
            f"Destination IP: {dest_ip}\n"
            f"Application: {app_name} ({app_details})\n"
            f"Port Info: {port_info}\n"
            f"Process ID: {process_id}\n"
            f"IP Background:\n{ip_background}\n"
            f"Flagged Signatures: {flagged_signatures}\n"
            "===============================\n"
        )
        self.flagged_logger.warning(message)
    def parse_payload(self, packet, app_name: str, payload: bytes) -> None:
        """Inspect a Raw payload: in sentry mode consult the ML model
        first (blocking the source on a positive), then run signature
        analysis and the per-application payload parsers."""
        if self.args.sentry:
            if self.ml_classifier.classify(payload):
                src_ip = packet[IP].src if packet.haslayer(IP) else "unknown"
                self.logger.warning("Sentry mode: payload classified as malicious. Blocking IP %s", src_ip)
                self.block_ip(src_ip)
                return
        self.logger.info("Parsing payload for application: %s", app_name)
        flagged_signatures = self.payload_parser.analyze_hex_dump(payload)
        if flagged_signatures and not self._is_noisy_discovery_traffic(packet):
            app_info = self.identify_application(
                packet[TCP].sport if packet.haslayer(TCP) else (packet[UDP].sport if packet.haslayer(UDP) else 0),
                packet[TCP].dport if packet.haslayer(TCP) else (packet[UDP].dport if packet.haslayer(UDP) else 0)
            )
            self.log_flagged_ip(packet, flagged_signatures, app_name, app_info[1])
        if app_name == "HTTP":
            self.payload_parser.parse_http_payload(payload)
        elif app_name == "HTTPS":
            # Only the first 64 bytes: enough for the TLS record header.
            self.payload_parser.parse_tls_payload(payload[:64])
        elif app_name == "DNS":
            self.dns_parser.parse_dns_payload(payload)
        else:
            self.payload_parser.parse_text_payload(payload)
    def packet_handler(self, packet) -> None:
        """Per-packet Scapy callback: applies mode filters, logs layer
        summaries, and dispatches payloads to the analysis pipeline."""
        # --expensive: every Raw payload goes straight to the LLM for a
        # verdict; nothing else in the pipeline runs for such packets.
        if self.args.expensive and packet.haslayer(Raw):
            payload = bytes(packet[Raw].load)
            normalized_payload = self.payload_parser.normalize_payload(payload)
            hex_payload = payload.hex()
            packet_summary = packet.summary()
            packet_details = packet.show(dump=True)
            self.logger.info("Expensive mode: analyzing packet with LLM.")
            label = llm_label_packet(packet_summary, normalized_payload, hex_payload, packet_details, logger=self.logger)
            if label == 1:
                self.logger.warning("Expensive Mode: Packet flagged as malicious by LLM.")
                self.log_flagged_ip(packet, flagged_signatures=[], app_name="Expensive Mode", app_details="LLM flagged malicious")
                src_ip = packet[IP].src if packet.haslayer(IP) else "unknown"
                self.block_ip(src_ip)
            else:
                self.logger.info("Expensive Mode: Packet deemed benign by LLM.")
            return
        # Source-address scope filter: --local-only keeps only private
        # sources; the default mode keeps only public sources.
        ip_str = None
        if packet.haslayer(IP):
            ip_str = packet[IP].src
            ip_obj = ipaddress.ip_address(ip_str)
            if self.args.local_only and not ip_obj.is_private:
                return
            if not self.args.local_only and ip_obj.is_private:
                return
        elif packet.haslayer(IPv6):
            ip_str = packet[IPv6].src
            ip_obj = ipaddress.ip_address(ip_str)
            if self.args.local_only and not ip_obj.is_private:
                return
            if not self.args.local_only and ip_obj.is_private:
                return
        else:
            self.logger.warning("Packet without IP/IPv6 layer")
            return
        # --red-alert: skip packets whose source country is whitelisted.
        if not self.args.local_only and self.args.red_alert:
            future = self.async_runner.run_coroutine(self.ip_lookup.get_ip_info(ip_str))
            try:
                ip_info = future.result(timeout=6)
                country = ip_info.get("country", "").upper()
                if country in self.args.whitelist:
                    self.logger.info("Skipping packet from whitelisted country: %s", country)
                    return
            except Exception as e:
                self.logger.exception("Error during red alert IP filtering: %s", e)
        self.logger.info("=" * 80)
        self.logger.info("Packet: %s", packet.summary())
        # Fallback DNS extraction from the textual summary, used when the
        # packet lacks a parseable DNS layer.
        summary_str = packet.summary()
        dns_match = re.search(r"DNS Qry b'([^']+)'", summary_str)
        if dns_match:
            dns_query = dns_match.group(1)
            if not self.dns_parser.is_blacklisted(dns_query):
                self.dns_logger.info("DNS Query (fallback): %s", dns_query)
        if packet.haslayer(DNS):
            dns_layer = packet[DNS]
            # qr == 0 marks a DNS query (as opposed to a response).
            if dns_layer.qr == 0 and dns_layer.qd is not None:
                try:
                    if isinstance(dns_layer.qd, DNSQR):
                        dns_query = (dns_layer.qd.qname.decode()
                                     if isinstance(dns_layer.qd.qname, bytes)
                                     else dns_layer.qd.qname)
                    else:
                        dns_query = ", ".join(
                            q.qname.decode() if isinstance(q.qname, bytes) else q.qname
                            for q in dns_layer.qd
                        )
                except Exception as e:
                    dns_query = str(dns_layer.qd)
                self.logger.info("DNS Query (from DNS layer): %s", dns_query)
                try:
                    raw_dns_payload = bytes(dns_layer)
                    self.dns_parser.parse_dns_payload(raw_dns_payload)
                except Exception as e:
                    self.logger.exception("Error processing DNS layer: %s", e)
        if packet.haslayer(Ether):
            eth = packet[Ether]
            self.logger.info("Ethernet: src=%s, dst=%s, type=0x%04x", eth.src, eth.dst, eth.type)
        else:
            self.logger.warning("No Ethernet layer found.")
            return
        if packet.haslayer(ARP):
            arp = packet[ARP]
            self.logger.info("ARP: op=%s, src=%s, dst=%s", arp.op, arp.psrc, arp.pdst)
            return
        if packet.haslayer(IP):
            ip_layer = packet[IP]
            self.logger.info("IPv4: src=%s, dst=%s, ttl=%s, proto=%s",
                             ip_layer.src, ip_layer.dst, ip_layer.ttl, ip_layer.proto)
        elif packet.haslayer(IPv6):
            ip_layer = packet[IPv6]
            self.logger.info("IPv6: src=%s, dst=%s, hlim=%s",
                             ip_layer.src, ip_layer.dst, ip_layer.hlim)
        if packet.haslayer(TCP):
            tcp_layer = packet[TCP]
            self.logger.info("TCP: sport=%s, dport=%s", tcp_layer.sport, tcp_layer.dport)
            app_name, app_details = self.identify_application(tcp_layer.sport, tcp_layer.dport)
            self.logger.info("Identified Application: %s (%s)", app_name, app_details)
            if packet.haslayer(Raw):
                tcp_payload = bytes(packet[Raw].load)
                # Feed the segment into stream reassembly + deep inspection.
                self._analyze_reassembled_tcp(
                    src_ip=ip_layer.src,
                    src_port=tcp_layer.sport,
                    dst_ip=ip_layer.dst,
                    dst_port=tcp_layer.dport,
                    tcp_seq=int(tcp_layer.seq),
                    payload=tcp_payload,
                )
            if app_name == "HTTPS" or (TLS and packet.haslayer(TLS)):
                if TLS and packet.haslayer(TLS):
                    tls_layer = packet[TLS]
                    self.logger.info("TLS Record: %s", tls_layer.summary())
                    if packet.haslayer(TLSClientHello):
                        client_hello = packet[TLSClientHello]
                        self.logger.info("TLS ClientHello: %s", client_hello.summary())
                        if hasattr(client_hello, 'servernames'):
                            self.logger.info("SNI: %s", client_hello.servernames)
                else:
                    # No scapy TLS layer available; parse raw record bytes.
                    if packet.haslayer(Raw):
                        payload = bytes(packet[Raw].load)
                        self.payload_parser.parse_tls_payload(payload)
                keylog = os.getenv("SSLKEYLOGFILE", "")
                if keylog:
                    self.logger.info("TLS plaintext decryption can use SSLKEYLOGFILE at: %s", keylog)
                else:
                    self.logger.info("TLS inspection limited to metadata (SNI/JA3/cert) unless SSLKEYLOGFILE or interception is configured.")
            else:
                if packet.haslayer(Raw):
                    payload = bytes(packet[Raw].load)
                    self.parse_payload(packet, app_name, payload)
        elif packet.haslayer(UDP):
            udp_layer = packet[UDP]
            self.logger.info("UDP: sport=%s, dport=%s", udp_layer.sport, udp_layer.dport)
            app_name, app_details = self.identify_application(udp_layer.sport, udp_layer.dport)
            self.logger.info("Identified Application: %s (%s)", app_name, app_details)
            if packet.haslayer(Raw):
                payload = bytes(packet[Raw].load)
                if app_name == "DNS":
                    parsed_dns = self.protocol_decoder.decode_dns(payload)
                    self.correlator.add_event(ip_layer.src, "dns_query", {"query": parsed_dns.get("query", "")})
                self.parse_payload(packet, app_name, payload)
        elif packet.haslayer(ICMP):
            icmp_layer = packet[ICMP]
            self.logger.info("ICMP: type=%s, code=%s", icmp_layer.type, icmp_layer.code)
        else:
            self.logger.warning("Unsupported transport layer.")
    def run(self) -> None:
        """Verify the capture backend, then sniff indefinitely on the
        configured interface, dispatching packets to packet_handler."""
        self.logger.info("Starting Enhanced Packet Sniffer on interface '%s'", self.args.interface)
        backend_ok, backend_msg = check_packet_capture_backend(logger=self.logger)
        if not backend_ok:
            self.logger.error("%s", backend_msg)
            return
        try:
            # NOTE(review): the "ip" BPF filter captures IPv4 only, so the
            # IPv6 branches in packet_handler are reachable only with
            # --local-only (empty filter) — confirm this is intended.
            bpf_filter = "ip" if not self.args.local_only else ""
            sniff(iface=self.args.interface, prn=self.packet_handler, store=0, filter=bpf_filter)
        except KeyboardInterrupt:
            self.logger.info("Stopping packet capture (KeyboardInterrupt received)...")
        except Exception as e:
            self.logger.exception("Error during packet capture: %s", e)
async def async_llm_label_packet(packet_summary: str, normalized_payload: str, hex_payload: str, packet_details: str, logger: Optional[logging.Logger] = None) -> int:
    """
    Ask the AsyncOpenAI chat API whether a packet looks malicious.

    Returns 1 when the model answers 'malicious', 0 for 'benign' or on
    any failure (missing client, missing API key, or API error).
    """
    logger = logger or logging.getLogger(__name__)
    if AsyncOpenAI is None:
        logger.error("AsyncOpenAI client is not available. Please install the openai package (>=1.0.0).")
        return 0
    api_key = os.getenv("OPENAI_API_KEY")
    if not api_key:
        logger.error("OpenAI API key not found in environment (OPENAI_API_KEY).")
        return 0
    oai_client = AsyncOpenAI(api_key=api_key)
    xdp_details = get_xdp_metadata_details()
    prompt = (
        f"Examine the following packet details and decide if the packet is malicious or benign.\n\n"
        f"Packet Summary:\n{packet_summary}\n\n"
        f"Packet Detailed Info (Scapy dump):\n{packet_details}\n\n"
        f"Normalized Payload (first 300 characters):\n{normalized_payload[:300]}\n\n"
        f"Hex Payload (first 200 characters):\n{hex_payload[:200]}\n\n"
        f"{xdp_details}\n\n"
        "Answer with a single word: 'malicious' or 'benign'."
    )
    logger.debug("LLM Prompt:\n%s", prompt)
    chat_messages = [
        {"role": "system", "content": "You are a network security analyst."},
        {"role": "user", "content": prompt}
    ]
    try:
        completion = await oai_client.chat.completions.create(
            model="gpt-3.5-turbo",
            messages=chat_messages,
            max_tokens=10,
            temperature=0
        )
        verdict = completion.choices[0].message.content.strip().lower()
        logger.debug("LLM response: %s", verdict)
        return 1 if "malicious" in verdict else 0
    except Exception as e:
        logger.exception("Error querying OpenAI API: %s", e)
        return 0
def llm_label_packet(packet_summary: str, normalized_payload: str, hex_payload: str, packet_details: str, logger: Optional[logging.Logger] = None) -> int:
    """
    Synchronous wrapper for async_llm_label_packet.

    Must be called from a thread without a running event loop (true for
    both the Scapy sniff callback and dataset-building paths). Returns 1
    for 'malicious', 0 for 'benign' or on any error.
    """
    logger = logger or logging.getLogger(__name__)
    try:
        # asyncio.run() is the idiomatic replacement for the manual
        # new_event_loop / run_until_complete / shutdown_asyncgens / close
        # sequence: it performs the same steps, also shuts down the default
        # executor, and never leaks the loop on error.
        return asyncio.run(
            async_llm_label_packet(packet_summary, normalized_payload, hex_payload, packet_details, logger=logger)
        )
    except Exception as e:
        logger.exception("Error in LLM labeling: %s", e)
        return 0
def build_dataset_main(interface: str, num_samples: int, output_path: str, use_llm: bool, logger: Optional[logging.Logger] = None) -> None:
    """
    Captures packets with a Raw payload and labels them.

    If use_llm is True, the LLM (via the async client) is used for automatic
    labeling; otherwise the user is prompted interactively. The results are
    saved as hex payload + label rows in a CSV file.

    BUG FIX: Scapy ignores the return value of the `prn` callback, so the
    previous "return True after num_samples" never stopped the capture —
    it only ended on the 60 s timeout. The supported `stop_filter` hook is
    used instead, and the callback ignores surplus packets that may still
    be delivered before the stop takes effect.
    """
    logger = logger or logging.getLogger(__name__)
    logger.info("Starting dataset capture on interface %s; capturing %d samples.", interface, num_samples)
    samples = []
    # Hoisted: one parser instance for the whole capture instead of one per packet.
    payload_parser = PayloadParser(logger=logger)

    def packet_callback(packet):
        # Skip once the quota is met, and skip packets without a payload.
        if len(samples) >= num_samples or not packet.haslayer(Raw):
            return
        payload = bytes(packet[Raw].load)
        hex_payload = payload.hex()
        normalized_payload = payload_parser.normalize_payload(payload)
        packet_summary = packet.summary()
        packet_details = packet.show(dump=True)
        print("\nPacket captured:")
        print(packet_summary)
        print("Packet Detailed Info:")
        print(packet_details)
        print("Normalized Payload (first 40 characters):", normalized_payload[:40])
        if use_llm:
            label = llm_label_packet(packet_summary, normalized_payload, hex_payload, packet_details, logger=logger)
            print(f"LLM labeled this packet as: {'malicious' if label == 1 else 'benign'}")
        else:
            user_input = input("Label this packet as malicious (1) or benign (0) [default=0]: ").strip()
            label = int(user_input) if user_input in ["0", "1"] else 0
        samples.append({"payload": hex_payload, "label": label})

    scapy.sniff(iface=interface, prn=packet_callback, store=0, timeout=60,
                stop_filter=lambda _pkt: len(samples) >= num_samples)
    if not samples:
        logger.error("No samples were captured.")
        return
    try:
        with open(output_path, "w", newline="") as csvfile:
            writer = csv.DictWriter(csvfile, fieldnames=["payload", "label"])
            writer.writeheader()
            writer.writerows(samples)
        logger.info("Dataset built and saved to %s (%d samples).", output_path, len(samples))
    except Exception as e:
        logger.error("Failed to write dataset to %s: %s", output_path, e)
def train_model_main(dataset_path: str, model_output_path: str, logger: Optional[logging.Logger] = None) -> None:
    """Thin CLI entry point that delegates model training to MLClassifier."""
    MLClassifier.train_model(dataset_path, model_output_path, logger=logger or logging.getLogger(__name__))
def get_missing_runtime_dependencies() -> List[str]:
    """Return the pip package names whose optional imports failed at
    startup (their module references are None)."""
    dependency_map = (
        ("aiohttp", aiohttp),
        ("PyYAML", yaml),
        ("cachetools", TTLCache),
        ("numpy", np),
        ("scapy", scapy),
    )
    return [package_name for package_name, module_ref in dependency_map if module_ref is None]
def detect_operating_system() -> str:
    """Map platform.system() onto a coarse OS family:
    'windows', 'linux', 'macos', or 'other'."""
    reported = platform.system().lower()
    for prefix, family in (("win", "windows"), ("linux", "linux"), ("darwin", "macos")):
        if reported.startswith(prefix):
            return family
    return "other"
def get_required_pip_packages(os_name: str) -> List[str]:
    """Return the pip packages to install for *os_name*; Linux
    additionally gets bcc for the XDP collector."""
    common = [
        "aiohttp",
        "PyYAML",
        "cachetools",
        "numpy",
        "scapy",
        "yara-python",
        "openai",
        "psutil",
        "pandas",
        "scikit-learn",
        "joblib",
    ]
    linux_extras = ["bcc"] if os_name == "linux" else []
    return common + linux_extras
def install_npcap_windows(logger: Optional[logging.Logger] = None) -> bool:
    """Ensure Npcap is usable on Windows.

    When the capture backend is already present this is a no-op; otherwise
    the official installer is downloaded and launched interactively, and
    the backend is re-checked afterwards. Returns True on success.
    """
    logger = logger or logging.getLogger(__name__)
    already_ok, _ = check_packet_capture_backend(logger=logger)
    if already_ok:
        logger.info("Npcap already available on this Windows machine.")
        return True
    npcap_url = "https://npcap.com/dist/npcap-1.87.exe"
    installer_path = os.path.join(os.environ.get("TEMP", "."), "npcap-1.87.exe")
    logger.info("Npcap not detected. Downloading official installer from %s", npcap_url)
    try:
        from urllib.request import urlretrieve
        urlretrieve(npcap_url, installer_path)
    except Exception as exc:
        logger.error("Failed to download Npcap installer: %s", exc)
        return False
    logger.info("Launching Npcap installer (may prompt for UAC/admin approval)...")
    try:
        subprocess.run([installer_path], check=True)
    except subprocess.CalledProcessError as exc:
        logger.error("Npcap installer exited with code %s", exc.returncode)
        return False
    except Exception as exc:
        logger.error("Failed to launch Npcap installer: %s", exc)
        return False
    verified, backend_msg = check_packet_capture_backend(logger=logger)
    if verified:
        logger.info("Npcap installation verified successfully.")
        return True
    logger.error("Npcap still not available after installer run: %s", backend_msg)
    return False
def install_dependencies(logger: Optional[logging.Logger] = None) -> bool:
    """Install the OS-specific pip dependency set, then (on Windows) make
    sure Npcap is installed.

    Tries `pip3` first and falls back to `python -m pip` when pip3 is
    missing or fails. Returns True on full success, False otherwise.
    """
    logger = logger or logging.getLogger(__name__)
    os_name = detect_operating_system()
    packages = get_required_pip_packages(os_name)
    logger.info("Detected operating system: %s", os_name)
    logger.info("Installing OS-specific Python dependencies...")
    logger.info("Installing dependencies with pip3...")
    pip_install_succeeded = False
    try:
        subprocess.run(["pip3", "install", *packages], check=True)
        logger.info("Dependencies installed successfully via pip3.")
        pip_install_succeeded = True
    except FileNotFoundError:
        logger.warning("pip3 not found. Falling back to python -m pip.")
    except subprocess.CalledProcessError as exc:
        logger.warning("pip3 install returned an error (%s). Falling back to python -m pip.", exc.returncode)
    if not pip_install_succeeded:
        try:
            subprocess.run([sys.executable, "-m", "pip", "install", *packages], check=True)
            logger.info("Dependencies installed successfully via python -m pip.")
        except subprocess.CalledProcessError as exc:
            logger.error("Dependency installation failed (exit code %s).", exc.returncode)
            return False
        except Exception as exc:
            logger.error("Dependency installation failed: %s", exc)
            return False
    # (Removed an unreachable "did not complete successfully" re-check:
    # every failure path above already returns False, so reaching this
    # point means pip installation succeeded.)
    if os_name == "windows" and not install_npcap_windows(logger=logger):
        logger.error("Windows install phase incomplete: Npcap setup failed.")
        return False
    return True
def is_likely_loopback(interface_name: str) -> bool:
    """Return True when *interface_name* looks like a loopback interface.

    BUG FIX: the previous implementation used a substring test for "lo",
    which wrongly classified real capture interfaces such as "wlo1" (a
    common onboard Wi-Fi name) as loopback. Bare loopback names ("lo",
    "lo0") are now matched exactly, while descriptive names (e.g.
    "Npcap Loopback Adapter", "Software Loopback Interface") are matched
    via the unambiguous "loopback" substring. Matching is case-insensitive.
    """
    lowered = interface_name.lower()
    if lowered in ("lo", "lo0"):
        return True
    return "loopback" in lowered
def auto_detect_interface(logger: Optional[logging.Logger] = None) -> str:
    """Pick a usable non-loopback capture interface.

    Tries, in order: Scapy's configured default, the interface routing to
    8.8.8.8, the first non-loopback entry of the interface list, and
    finally the very first listed interface. Raises RuntimeError when no
    candidate can be found.
    """
    logger = logger or logging.getLogger(__name__)
    try:
        candidate = str(scapy.conf.iface)
        if candidate and not is_likely_loopback(candidate):
            logger.info("Auto-detected interface from Scapy default: %s", candidate)
            return candidate
    except Exception:
        logger.debug("Unable to use Scapy default interface for auto-detection.")
    try:
        candidate = scapy.conf.route.route("8.8.8.8")[0]
        if candidate and not is_likely_loopback(candidate):
            logger.info("Auto-detected interface from routing table: %s", candidate)
            return candidate
    except Exception:
        logger.debug("Unable to use routing table for interface auto-detection.")
    try:
        available = scapy.get_if_list()
        usable = next((name for name in available if name and not is_likely_loopback(name)), None)
        if usable is not None:
            logger.info("Auto-detected interface from interface list: %s", usable)
            return usable
        if available:
            logger.info("Falling back to first available interface: %s", available[0])
            return available[0]
    except Exception as exc:
        logger.debug("Failed to list interfaces during auto-detection: %s", exc)
    raise RuntimeError("Could not auto-detect a usable network interface.")
def parse_arguments() -> argparse.Namespace:
    """Define and evaluate the command-line interface.

    :return: The parsed ``argparse.Namespace``; no post-validation is done
        here (e.g. ``whitelist`` is still the raw comma-separated string).
    """
    parser = argparse.ArgumentParser(
        description=(
            "Enhanced Packet Sniffer with Red Alert, Whitelist, Sentry Mode, "
            "Dataset Building, Training, and LLM-assisted Labeling"
        )
    )
    add = parser.add_argument
    # --- capture / setup -------------------------------------------------
    add("-i", "--interface", type=str, default="",
        help="Network interface to sniff on")
    add("--install", action="store_true",
        help="Install required Python dependencies using pip3 and exit")
    add("--start", action="store_true",
        help="Start sniffing with automatic interface detection when --interface is not provided")
    add("-l", "--logfile", type=str, default="sniffer.log",
        help="Path to the main log file")
    add("--no-bgcheck", action="store_true",
        help="Disable IP background lookup")
    add("--local-only", action="store_true",
        help="Capture only local (private) traffic")
    # --- alerting / protection -------------------------------------------
    add("--red-alert", action="store_true",
        help="Enable red alert mode: only log packets from non-allied countries")
    add("--whitelist", type=str, default="US",
        help="Comma-separated list of allied (whitelisted) country codes (default: US)")
    add("--sentry", action="store_true",
        help="Enable sentry mode (ML-based blocking of malicious packets)")
    add("--model-path", type=str, default="model.pkl",
        help="Path to the ML model file (used in sentry mode)")
    # --- dataset building / training -------------------------------------
    add("--train", action="store_true",
        help="Run training mode to build a new ML model from a dataset")
    add("--dataset", type=str, default="",
        help="Path to the CSV dataset for training")
    add("--build-dataset", action="store_true",
        help="Capture and build a labeled dataset interactively")
    add("--num-samples", type=int, default=10,
        help="Number of samples to capture when building the dataset")
    add("--dataset-out", type=str, default="training_data.csv",
        help="Path to save the built dataset CSV file")
    add("--llm-label", action="store_true",
        help="Use the OpenAI API to automatically label packets in dataset building mode")
    # --- advanced capture / analysis --------------------------------------
    add("--xdp", action="store_true",
        help="Enable XDP (Express Data Path) to gather additional packet metadata")
    add("--expensive", action="store_true",
        help="Use LLM to analyze all packets in realtime (expensive mode)")
    add("--yara-rules", type=str, default="",
        help="Path to YARA rules file for signature scanning")
    add("--carve-dir", type=str, default="carved_objects",
        help="Directory where carved files/objects are written")
    add("-v", "--verbosity", type=str, default="INFO",
        help="Logging level (DEBUG, INFO, WARNING, ERROR)")
    return parser.parse_args()
| if __name__ == "__main__": | |
| args = parse_arguments() | |
| args.whitelist = set(code.strip().upper() for code in args.whitelist.split(',')) | |
| logging_level = getattr(logging, args.verbosity.upper(), logging.INFO) | |
| setup_logging("logging.yaml") | |
| logging.getLogger().setLevel(logging_level) | |
| logger = logging.getLogger(__name__) | |
| if args.install: | |
| ok = install_dependencies(logger=logger) | |
| raise SystemExit(0 if ok else 1) | |
| missing_runtime_deps = get_missing_runtime_dependencies() | |
| if missing_runtime_deps: | |
| logger.error( | |
| "Missing required dependencies: %s. Run with --install first.", | |
| ", ".join(missing_runtime_deps) | |
| ) | |
| raise SystemExit(1) | |
| if args.build_dataset: | |
| if not args.interface and args.start: | |
| try: | |
| args.interface = auto_detect_interface(logger=logger) | |
| except Exception as exc: | |
| logger.error("--start failed to auto-detect an interface: %s", exc) | |
| raise SystemExit(1) | |
| if not args.interface: | |
| logger.error("Dataset mode requires --interface or --start for auto-detection") | |
| raise SystemExit(1) | |
| build_dataset_main(args.interface, args.num_samples, args.dataset_out, args.llm_label, logger=logger) | |
| elif args.train: | |
| if not args.dataset: | |
| logger.error("Training mode requires a dataset. Please provide --dataset <path>") | |
| else: | |
| train_model_main(args.dataset, args.model_path, logger=logger) | |
| else: | |
| if args.start and not args.interface: | |
| try: | |
| args.interface = auto_detect_interface(logger=logger) | |
| except Exception as exc: | |
| logger.error("--start failed to auto-detect an interface: %s", exc) | |
| raise SystemExit(1) | |
| if not args.interface: | |
| logger.error("Sniffer mode requires --interface or --start for auto-detection") | |
| raise SystemExit(1) | |
| sniffer = PacketSniffer(args) | |
| sniffer.run() |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment