Skip to content

Commit d5a37ed

Browse files
Refactor IP validation and whitelist logic in IPBlocker
Refactor IPBlocker to improve IP validation and whitelist handling.
1 parent 46b19ad commit d5a37ed

File tree

1 file changed

+61
-45
lines changed

1 file changed

+61
-45
lines changed

src/firewall/blocking.py

Lines changed: 61 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -8,55 +8,50 @@
88
from colorama import Fore, Style
99
from utils.logger import get_logger
1010
from utils.system import get_platform_firewall_command
11-
import ipaddress
11+
import ipaddress
1212

1313

1414
class IPBlocker:
1515
"""Handles IP blocking operations across different platforms"""
16-
16+
1717
def __init__(self, block_duration: int, whitelist: Set[str]):
1818
self.block_duration = block_duration
19-
# Normalize whitelist (support IPv4, IPv6, CIDR)
20-
self.whitelist = {str(ipaddress.ip_network(ip, strict=False)) for ip in whitelist}
19+
self.whitelist = whitelist
2120
self.blocked_ips: Dict[str, datetime] = {}
2221
self.lock = threading.Lock()
2322
self.logger = get_logger(__name__)
2423
self.platform = platform.system().lower()
2524
self.firewall_cmd = get_platform_firewall_command()
2625

26+
# Added helper for IP validation
2727
def _validate_ip(self, ip: str) -> str:
28-
"""Validate and normalize an IP or network string (IPv4/IPv6/CIDR supported)"""
28+
"""
29+
Validate and normalize an IP or CIDR (IPv4 or IPv6).
30+
Raises ValueError if invalid.
31+
"""
2932
try:
3033
network = ipaddress.ip_network(ip, strict=False)
3134
return str(network)
3235
except ValueError as e:
3336
self.logger.error(f"Invalid IP/network '{ip}': {e}")
3437
raise ValueError(f"Invalid IP/network: {ip}")
3538

36-
def _is_whitelisted(self, ip: str) -> bool:
37-
"""Check if IP is within whitelist networks"""
38-
try:
39-
ip_net = ipaddress.ip_network(ip, strict=False)
40-
return any(ip_net.subnet_of(ipaddress.ip_network(wl)) for wl in self.whitelist)
41-
except ValueError:
42-
return False
43-
4439
def block_ip(self, ip: str, reason: str) -> bool:
4540
"""Block an IP address using the appropriate firewall system"""
4641
try:
47-
ip = self._validate_ip(ip)
42+
ip = self._validate_ip(ip) # Validate IP first
4843
except ValueError:
4944
return False
5045

5146
if self._is_whitelisted(ip):
5247
self.logger.info(f"IP {ip} is whitelisted, not blocking")
5348
return False
54-
49+
5550
with self.lock:
5651
if ip not in self.blocked_ips:
5752
try:
5853
success = self._execute_block_command(ip)
59-
54+
6055
if success:
6156
self.blocked_ips[ip] = datetime.now()
6257
self.logger.warning(f"🚫 BLOCKED IP: {ip} - Reason: {reason}")
@@ -65,26 +60,26 @@ def block_ip(self, ip: str, reason: str) -> bool:
6560
else:
6661
self.logger.error(f"Failed to block IP {ip}")
6762
return False
68-
63+
6964
except Exception as e:
7065
self.logger.error(f"Exception while blocking IP {ip}: {e}")
7166
return False
7267
else:
7368
self.logger.debug(f"IP {ip} already blocked")
7469
return True
75-
70+
7671
def unblock_ip(self, ip: str) -> bool:
7772
"""Manually unblock a specific IP address"""
7873
try:
79-
ip = self._validate_ip(ip)
74+
ip = self._validate_ip(ip) # Validate IP before unblocking
8075
except ValueError:
8176
return False
8277

8378
with self.lock:
8479
if ip in self.blocked_ips:
8580
try:
8681
success = self._execute_unblock_command(ip)
87-
82+
8883
if success:
8984
del self.blocked_ips[ip]
9085
self.logger.info(f"✅ UNBLOCKED IP: {ip}")
@@ -93,7 +88,7 @@ def unblock_ip(self, ip: str) -> bool:
9388
else:
9489
self.logger.error(f"Failed to unblock IP {ip}")
9590
return False
96-
91+
9792
except Exception as e:
9893
self.logger.error(f"Exception while unblocking IP {ip}: {e}")
9994
return False
@@ -106,19 +101,23 @@ def unblock_expired_ips(self) -> List[str]:
106101
current_time = datetime.now()
107102
block_duration = timedelta(seconds=self.block_duration)
108103
unblocked_ips = []
109-
104+
110105
with self.lock:
111106
expired_ips = [
112107
ip for ip, block_time in self.blocked_ips.items()
113108
if current_time - block_time > block_duration
114109
]
115-
110+
116111
for ip in expired_ips:
117112
if self.unblock_ip(ip):
118113
unblocked_ips.append(ip)
119-
114+
120115
return unblocked_ips
121-
116+
117+
def _is_whitelisted(self, ip: str) -> bool:
118+
"""Check if IP is in whitelist"""
119+
return ip in self.whitelist
120+
122121
def _execute_block_command(self, ip: str) -> bool:
123122
"""Execute platform-specific block command"""
124123
try:
@@ -134,7 +133,7 @@ def _execute_block_command(self, ip: str) -> bool:
134133
except Exception as e:
135134
self.logger.error(f"Platform-specific blocking failed: {e}")
136135
return False
137-
136+
138137
def _execute_unblock_command(self, ip: str) -> bool:
139138
"""Execute platform-specific unblock command"""
140139
try:
@@ -150,58 +149,73 @@ def _execute_unblock_command(self, ip: str) -> bool:
150149
except Exception as e:
151150
self.logger.error(f"Platform-specific unblocking failed: {e}")
152151
return False
153-
154-
# -------- PLATFORM-SPECIFIC IMPLEMENTATIONS -------- #
155-
152+
156153
def _block_ip_linux(self, ip: str) -> bool:
157154
"""Block IP using iptables on Linux"""
158155
cmd = ['sudo', 'iptables', '-A', 'INPUT', '-s', ip, '-j', 'DROP']
159-
result = subprocess.run(cmd, capture_output=True, text=True)
156+
result = subprocess.run(cmd, check=True, capture_output=True, text=True)
160157
return result.returncode == 0
161-
158+
162159
def _unblock_ip_linux(self, ip: str) -> bool:
163160
"""Unblock IP using iptables on Linux"""
164161
cmd = ['sudo', 'iptables', '-D', 'INPUT', '-s', ip, '-j', 'DROP']
165162
result = subprocess.run(cmd, capture_output=True, text=True)
166163
return result.returncode == 0
167-
164+
168165
def _block_ip_macos(self, ip: str) -> bool:
169166
"""Block IP using pfctl on macOS"""
170167
cmd1 = ['sudo', 'pfctl', '-t', 'blocked_ips', '-T', 'add', ip]
171168
result1 = subprocess.run(cmd1, capture_output=True, text=True)
172-
subprocess.run(['sudo', 'pfctl', '-e'], capture_output=True, text=True)
169+
cmd2 = ['sudo', 'pfctl', '-e']
170+
subprocess.run(cmd2, capture_output=True, text=True)
173171
return result1.returncode == 0
174-
172+
175173
def _unblock_ip_macos(self, ip: str) -> bool:
176174
"""Unblock IP using pfctl on macOS"""
177175
cmd = ['sudo', 'pfctl', '-t', 'blocked_ips', '-T', 'delete', ip]
178176
result = subprocess.run(cmd, capture_output=True, text=True)
179177
return result.returncode == 0
180-
178+
181179
def _block_ip_windows(self, ip: str) -> bool:
182180
"""Block IP using Windows Firewall (netsh)"""
183-
rule_name = f"SimpleFirewall_Block_{ip.replace(':', '_').replace('.', '_')}"
181+
rule_name = f"SimpleFirewall_Block_{ip.replace('.', '_')}"
184182
cmd = [
185183
'netsh', 'advfirewall', 'firewall', 'add', 'rule',
186184
f'name={rule_name}',
187185
'dir=in',
188186
'action=block',
189187
f'remoteip={ip}'
190188
]
191-
result = subprocess.run(cmd, capture_output=True, text=True)
192-
return result.returncode == 0
189+
try:
190+
result = subprocess.run(cmd, capture_output=True, text=True)
191+
if result.returncode == 0:
192+
self.logger.debug(f"netsh add rule stdout: {result.stdout.strip()}")
193+
return True
194+
else:
195+
self.logger.error(f"netsh add rule failed: rc={result.returncode} stdout={result.stdout.strip()} stderr={result.stderr.strip()}")
196+
return False
197+
except Exception as e:
198+
self.logger.error(f"Exception when running netsh add rule: {e}")
199+
return False
193200

194201
def _unblock_ip_windows(self, ip: str) -> bool:
195202
"""Unblock IP using Windows Firewall (netsh)"""
196-
rule_name = f"SimpleFirewall_Block_{ip.replace(':', '_').replace('.', '_')}"
203+
rule_name = f"SimpleFirewall_Block_{ip.replace('.', '_')}"
197204
cmd = [
198205
'netsh', 'advfirewall', 'firewall', 'delete', 'rule',
199206
f'name={rule_name}'
200207
]
201-
result = subprocess.run(cmd, capture_output=True, text=True)
202-
return result.returncode == 0
203-
204-
# --------------------------------------------------- #
208+
try:
209+
result = subprocess.run(cmd, capture_output=True, text=True)
210+
if result.returncode == 0:
211+
self.logger.debug(f"netsh delete rule stdout: {result.stdout.strip()}")
212+
return True
213+
else:
214+
self.logger.error(f"netsh delete rule failed: rc={result.returncode} stdout={result.stdout.strip()} stderr={result.stderr.strip()}")
215+
return False
216+
except Exception as e:
217+
self.logger.error(f"Exception when running netsh delete rule: {e}")
218+
return False
205219

206220
def get_blocked_ips(self) -> Dict[str, str]:
207221
"""Get currently blocked IPs with their block times"""
@@ -210,16 +224,18 @@ def get_blocked_ips(self) -> Dict[str, str]:
210224
ip: block_time.isoformat()
211225
for ip, block_time in self.blocked_ips.items()
212226
}
213-
227+
214228
def cleanup_all_blocks(self) -> List[str]:
215229
"""Remove all blocks (useful for shutdown)"""
216230
cleaned_ips = []
231+
217232
with self.lock:
218233
for ip in list(self.blocked_ips.keys()):
219234
if self.unblock_ip(ip):
220235
cleaned_ips.append(ip)
236+
221237
return cleaned_ips
222-
238+
223239
def get_stats(self) -> Dict[str, int]:
224240
"""Get blocking statistics"""
225241
with self.lock:

0 commit comments

Comments
 (0)