Skip to content

Commit 46b19ad

Browse files
Enhance IP blocking with validation and normalization
Refactor IP blocking logic to support IP normalization and whitelisting. Added validation and improved platform-specific blocking commands.
1 parent c57ab42 commit 46b19ad

File tree

1 file changed

+64
-62
lines changed

1 file changed

+64
-62
lines changed

src/firewall/blocking.py

Lines changed: 64 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -8,31 +8,55 @@
88
from colorama import Fore, Style
99
from utils.logger import get_logger
1010
from utils.system import get_platform_firewall_command
11+
import ipaddress
1112

1213

1314
class IPBlocker:
1415
"""Handles IP blocking operations across different platforms"""
15-
16+
1617
def __init__(self, block_duration: int, whitelist: Set[str]):
1718
self.block_duration = block_duration
18-
self.whitelist = whitelist
19+
# Normalize whitelist (support IPv4, IPv6, CIDR)
20+
self.whitelist = {str(ipaddress.ip_network(ip, strict=False)) for ip in whitelist}
1921
self.blocked_ips: Dict[str, datetime] = {}
2022
self.lock = threading.Lock()
2123
self.logger = get_logger(__name__)
2224
self.platform = platform.system().lower()
2325
self.firewall_cmd = get_platform_firewall_command()
24-
26+
27+
def _validate_ip(self, ip: str) -> str:
28+
"""Validate and normalize an IP or network string (IPv4/IPv6/CIDR supported)"""
29+
try:
30+
network = ipaddress.ip_network(ip, strict=False)
31+
return str(network)
32+
except ValueError as e:
33+
self.logger.error(f"Invalid IP/network '{ip}': {e}")
34+
raise ValueError(f"Invalid IP/network: {ip}")
35+
36+
def _is_whitelisted(self, ip: str) -> bool:
37+
"""Check if IP is within whitelist networks"""
38+
try:
39+
ip_net = ipaddress.ip_network(ip, strict=False)
40+
return any(ip_net.subnet_of(ipaddress.ip_network(wl)) for wl in self.whitelist)
41+
except ValueError:
42+
return False
43+
2544
def block_ip(self, ip: str, reason: str) -> bool:
2645
"""Block an IP address using the appropriate firewall system"""
46+
try:
47+
ip = self._validate_ip(ip)
48+
except ValueError:
49+
return False
50+
2751
if self._is_whitelisted(ip):
2852
self.logger.info(f"IP {ip} is whitelisted, not blocking")
2953
return False
30-
54+
3155
with self.lock:
3256
if ip not in self.blocked_ips:
3357
try:
3458
success = self._execute_block_command(ip)
35-
59+
3660
if success:
3761
self.blocked_ips[ip] = datetime.now()
3862
self.logger.warning(f"🚫 BLOCKED IP: {ip} - Reason: {reason}")
@@ -41,21 +65,26 @@ def block_ip(self, ip: str, reason: str) -> bool:
4165
else:
4266
self.logger.error(f"Failed to block IP {ip}")
4367
return False
44-
68+
4569
except Exception as e:
4670
self.logger.error(f"Exception while blocking IP {ip}: {e}")
4771
return False
4872
else:
4973
self.logger.debug(f"IP {ip} already blocked")
5074
return True
51-
75+
5276
def unblock_ip(self, ip: str) -> bool:
5377
"""Manually unblock a specific IP address"""
78+
try:
79+
ip = self._validate_ip(ip)
80+
except ValueError:
81+
return False
82+
5483
with self.lock:
5584
if ip in self.blocked_ips:
5685
try:
5786
success = self._execute_unblock_command(ip)
58-
87+
5988
if success:
6089
del self.blocked_ips[ip]
6190
self.logger.info(f"✅ UNBLOCKED IP: {ip}")
@@ -64,36 +93,32 @@ def unblock_ip(self, ip: str) -> bool:
6493
else:
6594
self.logger.error(f"Failed to unblock IP {ip}")
6695
return False
67-
96+
6897
except Exception as e:
6998
self.logger.error(f"Exception while unblocking IP {ip}: {e}")
7099
return False
71100
else:
72101
self.logger.warning(f"IP {ip} was not blocked")
73102
return False
74-
103+
75104
def unblock_expired_ips(self) -> List[str]:
76105
"""Unblock IPs that have exceeded the block duration"""
77106
current_time = datetime.now()
78107
block_duration = timedelta(seconds=self.block_duration)
79108
unblocked_ips = []
80-
109+
81110
with self.lock:
82111
expired_ips = [
83112
ip for ip, block_time in self.blocked_ips.items()
84113
if current_time - block_time > block_duration
85114
]
86-
115+
87116
for ip in expired_ips:
88117
if self.unblock_ip(ip):
89118
unblocked_ips.append(ip)
90-
119+
91120
return unblocked_ips
92-
93-
def _is_whitelisted(self, ip: str) -> bool:
94-
"""Check if IP is in whitelist"""
95-
return ip in self.whitelist
96-
121+
97122
def _execute_block_command(self, ip: str) -> bool:
98123
"""Execute platform-specific block command"""
99124
try:
@@ -109,7 +134,7 @@ def _execute_block_command(self, ip: str) -> bool:
109134
except Exception as e:
110135
self.logger.error(f"Platform-specific blocking failed: {e}")
111136
return False
112-
137+
113138
def _execute_unblock_command(self, ip: str) -> bool:
114139
"""Execute platform-specific unblock command"""
115140
try:
@@ -125,103 +150,80 @@ def _execute_unblock_command(self, ip: str) -> bool:
125150
except Exception as e:
126151
self.logger.error(f"Platform-specific unblocking failed: {e}")
127152
return False
128-
153+
154+
# -------- PLATFORM-SPECIFIC IMPLEMENTATIONS -------- #
155+
129156
def _block_ip_linux(self, ip: str) -> bool:
130157
"""Block IP using iptables on Linux"""
131158
cmd = ['sudo', 'iptables', '-A', 'INPUT', '-s', ip, '-j', 'DROP']
132-
result = subprocess.run(cmd, check=True, capture_output=True, text=True)
159+
result = subprocess.run(cmd, capture_output=True, text=True)
133160
return result.returncode == 0
134-
161+
135162
def _unblock_ip_linux(self, ip: str) -> bool:
136163
"""Unblock IP using iptables on Linux"""
137164
cmd = ['sudo', 'iptables', '-D', 'INPUT', '-s', ip, '-j', 'DROP']
138165
result = subprocess.run(cmd, capture_output=True, text=True)
139166
return result.returncode == 0
140-
167+
141168
def _block_ip_macos(self, ip: str) -> bool:
142169
"""Block IP using pfctl on macOS"""
143-
# First, add IP to a table
144170
cmd1 = ['sudo', 'pfctl', '-t', 'blocked_ips', '-T', 'add', ip]
145171
result1 = subprocess.run(cmd1, capture_output=True, text=True)
146-
147-
# Then enable the blocking rule (this might need to be done once)
148-
cmd2 = ['sudo', 'pfctl', '-e']
149-
result2 = subprocess.run(cmd2, capture_output=True, text=True)
150-
172+
subprocess.run(['sudo', 'pfctl', '-e'], capture_output=True, text=True)
151173
return result1.returncode == 0
152-
174+
153175
def _unblock_ip_macos(self, ip: str) -> bool:
154176
"""Unblock IP using pfctl on macOS"""
155177
cmd = ['sudo', 'pfctl', '-t', 'blocked_ips', '-T', 'delete', ip]
156178
result = subprocess.run(cmd, capture_output=True, text=True)
157179
return result.returncode == 0
158-
180+
159181
def _block_ip_windows(self, ip: str) -> bool:
160182
"""Block IP using Windows Firewall (netsh)"""
161-
rule_name = f"SimpleFirewall_Block_{ip.replace('.', '_')}"
183+
rule_name = f"SimpleFirewall_Block_{ip.replace(':', '_').replace('.', '_')}"
162184
cmd = [
163185
'netsh', 'advfirewall', 'firewall', 'add', 'rule',
164186
f'name={rule_name}',
165187
'dir=in',
166188
'action=block',
167189
f'remoteip={ip}'
168190
]
169-
try:
170-
result = subprocess.run(cmd, capture_output=True, text=True)
171-
if result.returncode == 0:
172-
self.logger.debug(f"netsh add rule stdout: {result.stdout.strip()}")
173-
return True
174-
else:
175-
self.logger.error(f"netsh add rule failed: rc={result.returncode} stdout={result.stdout.strip()} stderr={result.stderr.strip()}")
176-
return False
177-
except Exception as e:
178-
self.logger.error(f"Exception when running netsh add rule: {e}")
179-
return False
191+
result = subprocess.run(cmd, capture_output=True, text=True)
192+
return result.returncode == 0
180193

181194
def _unblock_ip_windows(self, ip: str) -> bool:
182195
"""Unblock IP using Windows Firewall (netsh)"""
183-
rule_name = f"SimpleFirewall_Block_{ip.replace('.', '_')}"
196+
rule_name = f"SimpleFirewall_Block_{ip.replace(':', '_').replace('.', '_')}"
184197
cmd = [
185198
'netsh', 'advfirewall', 'firewall', 'delete', 'rule',
186199
f'name={rule_name}'
187200
]
188-
try:
189-
result = subprocess.run(cmd, capture_output=True, text=True)
190-
if result.returncode == 0:
191-
self.logger.debug(f"netsh delete rule stdout: {result.stdout.strip()}")
192-
return True
193-
else:
194-
# Sometimes netsh returns 1 when rule not found; log and return False
195-
self.logger.error(f"netsh delete rule failed: rc={result.returncode} stdout={result.stdout.strip()} stderr={result.stderr.strip()}")
196-
return False
197-
except Exception as e:
198-
self.logger.error(f"Exception when running netsh delete rule: {e}")
199-
return False
201+
result = subprocess.run(cmd, capture_output=True, text=True)
202+
return result.returncode == 0
203+
204+
# --------------------------------------------------- #
200205

201-
202206
def get_blocked_ips(self) -> Dict[str, str]:
203207
"""Get currently blocked IPs with their block times"""
204208
with self.lock:
205209
return {
206210
ip: block_time.isoformat()
207211
for ip, block_time in self.blocked_ips.items()
208212
}
209-
213+
210214
def cleanup_all_blocks(self) -> List[str]:
211215
"""Remove all blocks (useful for shutdown)"""
212216
cleaned_ips = []
213-
214217
with self.lock:
215218
for ip in list(self.blocked_ips.keys()):
216219
if self.unblock_ip(ip):
217220
cleaned_ips.append(ip)
218-
219221
return cleaned_ips
220-
222+
221223
def get_stats(self) -> Dict[str, int]:
222224
"""Get blocking statistics"""
223225
with self.lock:
224226
return {
225227
'currently_blocked': len(self.blocked_ips),
226228
'whitelist_size': len(self.whitelist)
227-
}
229+
}

0 commit comments

Comments
 (0)