Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 28 additions & 10 deletions src/firewall/blocking.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from colorama import Fore, Style
from utils.logger import get_logger
from utils.system import get_platform_firewall_command
import ipaddress


class IPBlocker:
Expand All @@ -21,9 +22,27 @@ def __init__(self, block_duration: int, whitelist: Set[str]):
self.logger = get_logger(__name__)
self.platform = platform.system().lower()
self.firewall_cmd = get_platform_firewall_command()


# Added helper for IP validation
def _validate_ip(self, ip: str) -> str:
"""
Validate and normalize an IP or CIDR (IPv4 or IPv6).
Raises ValueError if invalid.
"""
try:
network = ipaddress.ip_network(ip, strict=False)
return str(network)
except ValueError as e:
self.logger.error(f"Invalid IP/network '{ip}': {e}")
raise ValueError(f"Invalid IP/network: {ip}")

def block_ip(self, ip: str, reason: str) -> bool:
"""Block an IP address using the appropriate firewall system"""
try:
ip = self._validate_ip(ip) # Validate IP first
except ValueError:
return False

if self._is_whitelisted(ip):
self.logger.info(f"IP {ip} is whitelisted, not blocking")
return False
Expand Down Expand Up @@ -51,6 +70,11 @@ def block_ip(self, ip: str, reason: str) -> bool:

def unblock_ip(self, ip: str) -> bool:
"""Manually unblock a specific IP address"""
try:
ip = self._validate_ip(ip) # Validate IP before unblocking
except ValueError:
return False

with self.lock:
if ip in self.blocked_ips:
try:
Expand All @@ -71,7 +95,7 @@ def unblock_ip(self, ip: str) -> bool:
else:
self.logger.warning(f"IP {ip} was not blocked")
return False

def unblock_expired_ips(self) -> List[str]:
"""Unblock IPs that have exceeded the block duration"""
current_time = datetime.now()
Expand Down Expand Up @@ -140,14 +164,10 @@ def _unblock_ip_linux(self, ip: str) -> bool:

def _block_ip_macos(self, ip: str) -> bool:
"""Block IP using pfctl on macOS"""
# First, add IP to a table
cmd1 = ['sudo', 'pfctl', '-t', 'blocked_ips', '-T', 'add', ip]
result1 = subprocess.run(cmd1, capture_output=True, text=True)

# Then enable the blocking rule (this might need to be done once)
cmd2 = ['sudo', 'pfctl', '-e']
result2 = subprocess.run(cmd2, capture_output=True, text=True)

subprocess.run(cmd2, capture_output=True, text=True)
return result1.returncode == 0

def _unblock_ip_macos(self, ip: str) -> bool:
Expand Down Expand Up @@ -191,14 +211,12 @@ def _unblock_ip_windows(self, ip: str) -> bool:
self.logger.debug(f"netsh delete rule stdout: {result.stdout.strip()}")
return True
else:
# Sometimes netsh returns 1 when rule not found; log and return False
self.logger.error(f"netsh delete rule failed: rc={result.returncode} stdout={result.stdout.strip()} stderr={result.stderr.strip()}")
return False
except Exception as e:
self.logger.error(f"Exception when running netsh delete rule: {e}")
return False


def get_blocked_ips(self) -> Dict[str, str]:
"""Get currently blocked IPs with their block times"""
with self.lock:
Expand All @@ -224,4 +242,4 @@ def get_stats(self) -> Dict[str, int]:
return {
'currently_blocked': len(self.blocked_ips),
'whitelist_size': len(self.whitelist)
}
}
Loading