diff --git a/AutoBlockIPList.py b/AutoBlockIPList.py
index 8972f3c..1149f45 100644
--- a/AutoBlockIPList.py
+++ b/AutoBlockIPList.py
@@ -10,181 +10,293 @@ import ipaddress
 import time
 from functools import reduce
-
-
-VERSION = "1.0.1"
-
-
-def create_connection(db_file):
-    try:
-        return sqlite3.connect(db_file)
-    except sqlite3.Error as e:
-        raise e
-
-
-def get_ip_remote(link):
-    data = ""
-    try:
-        r = requests.get(link)
-        r.raise_for_status()
-        data = r.text.replace("\r", "")
-    except requests.exceptions.RequestException as e:
-        verbose(f"Unable to connect to {link}")
-    return data
-
-
-def get_ip_local(file):
-    return file.read().replace("\r", "")
-
-
-def get_ip_list(local, external):
-    data = [get_ip_local(f).split("\n") for f in local] + [get_ip_remote(s).split("\n") for s in external]
-    ip = reduce(lambda a, b: a + b, data)
-    return ip
-
-
-def ipv4_to_ipstd(ipv4):
-    numbers = [int(bits) for bits in ipv4.split('.')]
-    return '0000:0000:0000:0000:0000:ffff:{:02x}{:02x}:{:02x}{:02x}'.format(*numbers).upper()
-
-
-def ipv6_to_ipstd(ipv6):
-    return ipaddress.ip_address(ipv6).exploded.upper()
-
-
-def process_ip(ip_list, expire):
-    processed = []
-    invalid = []
-    for i in ip_list:
+from typing import List, Tuple, Optional, TextIO, Callable
+from pathlib import Path
+import logging
+from dataclasses import dataclass
+import sys
+from concurrent.futures import ThreadPoolExecutor, as_completed
+
+VERSION = "1.1.0"
+
+# Setup logging
+logging.basicConfig(
+    level=logging.INFO,
+    format='%(asctime)s - %(levelname)s - %(message)s',
+    handlers=[logging.StreamHandler()]
+)
+logger = logging.getLogger(__name__)
+
+@dataclass
+class IPEntry:
+    """Data class for IP address entries"""
+    ip: str
+    ip_std: str
+    expire: int
+
+class DatabaseManager:
+    """Handles all database operations"""
+    def __init__(self, db_path: str):
+        self.db_path = db_path
+        self._validate_db()
+
+    def _validate_db(self) -> None:
+        """Validate database file exists and is accessible"""
+        if not Path(self.db_path).is_file():
+            raise FileNotFoundError(f"No such file or directory: '{self.db_path}'")
+        if not os.access(self.db_path, os.R_OK):
+            raise PermissionError("Unable to read database. Run this script with sudo or root user.")
+
+    def create_connection(self) -> sqlite3.Connection:
+        """Create a database connection"""
         try:
-            ip = ipaddress.ip_address(i)
-            if ip.version == 4:
-                ipstd = ipv4_to_ipstd(i)
-            elif ip.version == 6:
-                ipstd = ipv6_to_ipstd(i)
+            return sqlite3.connect(self.db_path)
+        except sqlite3.Error as e:
+            logger.error(f"Database connection error: {e}")
+            raise
+
+    def execute_query(self, query: str, params=None, many: bool = False) -> None:
+        """Execute a single SQL query (or many, when a sequence of parameter tuples is given)"""
+        with self.create_connection() as conn:
+            cursor = conn.cursor()
+            if many and params:
+                cursor.executemany(query, params)
+            elif params:
+                cursor.execute(query, params)
             else:
-                ipstd = ""
-            processed.append([i, ipstd, expire])
-        except ValueError:
-            if i != "":
-                invalid.append(i)
-    return processed, invalid
-
-
-def url(link):
-    if validators.url(link) != True:
-        raise argparse.ArgumentError
-    return link
-
-
-def folder(attr='r'):
-    def check_folder(path):
-        if os.path.isdir(path):
-            if attr == 'r' and not os.access(path, os.R_OK):
-                raise argparse.ArgumentTypeError(f'"{path}" is not readable.')
-            if attr == 'w' and not os.access(path, os.W_OK):
-                raise argparse.ArgumentTypeError(f'"{path}" is not writable.')
-            return os.path.abspath(path)
-        else:
-            raise argparse.ArgumentTypeError(f'"{path}" is not a valid path.')
-    return check_folder
-
-
-def verbose(message):
-    global args
-    if args.verbose:
-        print(message)
-
-
-def parse_args():
-    parser = argparse.ArgumentParser(prog='AutoBlockIPList')
-    parser.add_argument("-f","--in-file", nargs='*', type=argparse.FileType('r'), default=[],
-                        help="Local list file separated by a space "
-                             "(eg. /home/user/list.txt custom.txt)")
-    parser.add_argument("-u", "--in-url", nargs="*", type=url, default=[],
-                        help="External list url separated by a space "
-                             "(eg https://example.com/list.txt https://example.com/all.txt)")
-    parser.add_argument("-e", "--expire-in-day", type=int, default=0,
-                        help="Expire time in day. Default 0: no expiration")
-    parser.add_argument("--remove-expired", action='store_true',
-                        help="Remove expired entry")
-    parser.add_argument("-b", "--backup-to", type=folder('w'),
-                        help="Folder to store a backup of the database")
-    parser.add_argument("--clear-db", action='store_true',
-                        help="Clear ALL deny entry in database before filling")
-    parser.add_argument("--dry-run", action='store_true',
-                        help="Perform a run without any modifications")
-    parser.add_argument("-v", "--verbose", action='store_true',
-                        help="Increase output verbosity")
-    parser.add_argument('--version', action='version', version=f'%(prog)s version {VERSION}')
-
-    a = parser.parse_args()
-
-    if len(a.in_file) == 0 and len(a.in_url) == 0:
-        raise parser.error("At least one source list is mandatory (file or url)")
-    if a.clear_db and a.backup_to is None:
-        raise parser.error("backup folder should be set for clear db")
-    if a.dry_run:
-        a.verbose = True
-
-    return a
-
+                cursor.execute(query)
+            conn.commit()
+
+    def fetch_one(self, query: str, params: Optional[tuple] = None) -> Optional[tuple]:
+        """Fetch a single result from database"""
+        with self.create_connection() as conn:
+            cursor = conn.cursor()
+            if params:
+                cursor.execute(query, params)
+            else:
+                cursor.execute(query)
+            return cursor.fetchone()
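+
+    # A minimal usage sketch (illustrative only; the queries are the same
+    # ones main() issues against the stock AutoBlockIP table):
+    #
+    #   db = DatabaseManager("/etc/synoautoblock.db")  # raises if missing/unreadable
+    #   row = db.fetch_one("SELECT COUNT(IP) FROM AutoBlockIP WHERE Deny = 1")
+    #   db.execute_query("DELETE FROM AutoBlockIP WHERE Deny = 1")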
+
+class IPProcessor:
+    """Handles IP address processing and validation"""
+    @staticmethod
+    def ipv4_to_ipstd(ipv4: str) -> str:
+        """Convert IPv4 to standardized IPv6 format"""
+        numbers = [int(bits) for bits in ipv4.split('.')]
+        return f"0000:0000:0000:0000:0000:ffff:{numbers[0]:02x}{numbers[1]:02x}:{numbers[2]:02x}{numbers[3]:02x}".upper()
+
+    @staticmethod
+    def ipv6_to_ipstd(ipv6: str) -> str:
+        """Convert IPv6 to standardized format"""
+        return ipaddress.ip_address(ipv6).exploded.upper()
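+
+    # Worked example: "192.0.2.1" -> bytes c0 00 02 01, so ipv4_to_ipstd()
+    # returns "0000:0000:0000:0000:0000:FFFF:C000:0201", the IPv4-mapped
+    # IPv6 form this script writes to the IPStd column.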
+
+    @classmethod
+    def process_ip_list(cls, ip_list: List[str], expire: int) -> Tuple[List[IPEntry], List[str]]:
+        """Process raw IP list into structured data"""
+        processed = []
+        invalid = []
+
+        for ip in filter(None, ip_list):  # Filter out empty strings
+            try:
+                ip = ip.strip()
+                ip_obj = ipaddress.ip_address(ip)
+                ipstd = cls.ipv4_to_ipstd(ip) if ip_obj.version == 4 else cls.ipv6_to_ipstd(ip)
+                processed.append(IPEntry(ip, ipstd, expire))
+            except ValueError:
+                invalid.append(ip)
+
+        return processed, invalid
+
+class IPFetcher:
+    """Handles fetching IP addresses from various sources"""
+    @staticmethod
+    def fetch_remote(url: str, timeout: int = 10) -> str:
+        """Fetch IP list from remote URL"""
+        try:
+            response = requests.get(url, timeout=timeout)
+            response.raise_for_status()
+            return response.text.replace("\r", "")
+        except requests.RequestException as e:
+            logger.warning(f"Failed to fetch from {url}: {e}")
+            return ""
+
+    @staticmethod
+    def fetch_local(file: TextIO) -> str:
+        """Read IP list from local file"""
+        return file.read().replace("\r", "")
+
+    @classmethod
+    def fetch_all(cls, local_files: List[TextIO], remote_urls: List[str], max_workers: int = 5) -> List[str]:
+        """Fetch IP lists from all sources concurrently"""
+        ip_lists = []
+
+        # Process local files
+        for file in local_files:
+            ip_lists.append(cls.fetch_local(file).split("\n"))
+
+        # Process remote URLs concurrently
+        with ThreadPoolExecutor(max_workers=max_workers) as executor:
+            future_to_url = {
+                executor.submit(cls.fetch_remote, url): url
+                for url in remote_urls
+            }
+            for future in as_completed(future_to_url):
+                ip_lists.append(future.result().split("\n"))
+
+        # Flatten the list of lists, de-duplicate, and strip whitespace
+        return list(set(ip.strip() for sublist in ip_lists for ip in sublist if ip.strip()))
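+
+    # Note: set() de-duplicates across all sources but discards input order,
+    # which is acceptable here because AutoBlockIP entries are keyed by IP.
+    # Usage sketch (the URL is a placeholder):
+    #   ips = IPFetcher.fetch_all([], ["https://example.com/list.txt"])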
+
+class BackupManager:
+    """Handles database backup operations"""
+    @staticmethod
+    def create_backup(db_path: str, backup_dir: str) -> str:
+        """Create a timestamped backup of the database"""
+        backup_dir = Path(backup_dir)
+        if not backup_dir.exists():
+            backup_dir.mkdir(parents=True, exist_ok=True)
+
+        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+        backup_path = backup_dir / f"{timestamp}_synoautoblock_backup.db"
+
+        try:
+            shutil.copy2(db_path, backup_path)
+            logger.info(f"Database backup created at {backup_path}")
+            return str(backup_path)
+        except IOError as e:
+            logger.error(f"Backup failed: {e}")
+            raise
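+
+    # Example: a run on 2024-01-31 09:05:00 (illustrative date) would copy
+    # /etc/synoautoblock.db to <backup_dir>/20240131_090500_synoautoblock_backup.db.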
+
+class ArgumentParser:
+    """Handles command line argument parsing"""
+    @staticmethod
+    def url(link: str) -> str:
+        """Validate URL argument"""
+        if validators.url(link) is not True:
+            raise argparse.ArgumentTypeError(f"Invalid URL: {link}")
+        return link
+
+    @staticmethod
+    def folder(access: str = 'r') -> Callable[[str], Path]:
+        """Build a validator for a folder path argument"""
+        def check_folder(path: str) -> Path:
+            path_obj = Path(path)
+            if not path_obj.is_dir():
+                raise argparse.ArgumentTypeError(f"'{path}' is not a valid directory")
+            if access == 'r' and not os.access(path, os.R_OK):
+                raise argparse.ArgumentTypeError(f"'{path}' is not readable")
+            if access == 'w' and not os.access(path, os.W_OK):
+                raise argparse.ArgumentTypeError(f"'{path}' is not writable")
+            return path_obj
+        return check_folder
+
+    @classmethod
+    def parse_args(cls) -> argparse.Namespace:
+        """Parse command line arguments"""
+        parser = argparse.ArgumentParser(prog='AutoBlockIPList',
+                                         description='Advanced IP blocking utility for Synology NAS')
+
+        # Input sources
+        input_group = parser.add_argument_group('Input Sources')
+        input_group.add_argument("-f", "--in-file", nargs='*', type=argparse.FileType('r'), default=[],
+                                 help="Local IP list files (space separated)")
+        input_group.add_argument("-u", "--in-url", nargs="*", type=cls.url, default=[],
+                                 help="Remote IP list URLs (space separated)")
+
+        # Database operations
+        db_group = parser.add_argument_group('Database Operations')
+        db_group.add_argument("-e", "--expire-in-day", type=int, default=0,
+                              help="Expiration time in days (0 = no expiration)")
+        db_group.add_argument("--remove-expired", action='store_true',
+                              help="Remove expired entries from database")
+        db_group.add_argument("--clear-db", action='store_true',
+                              help="Clear ALL deny entries before processing")
+
+        # Additional options
+        opt_group = parser.add_argument_group('Additional Options')
+        opt_group.add_argument("-b", "--backup-to", type=cls.folder('w'),
+                               help="Directory for database backups")
+        opt_group.add_argument("--max-workers", type=int, default=5,
+                               help="Maximum concurrent workers for URL fetching")
+        opt_group.add_argument("--dry-run", action='store_true',
+                               help="Simulate without making changes")
+        opt_group.add_argument("-v", "--verbose", action='store_true',
+                               help="Enable verbose output")
+        opt_group.add_argument("--version", action='version',
+                               version=f'%(prog)s version {VERSION}')
+
+        args = parser.parse_args()
+
+        # Validation
+        if not args.in_file and not args.in_url:
+            parser.error("At least one input source is required (file or URL)")
+        if args.clear_db and not args.backup_to:
+            parser.error("Backup directory is required when using --clear-db")
+        if args.dry_run:
+            args.verbose = True
+
+        # Set logging level
+        logger.setLevel(logging.DEBUG if args.verbose else logging.INFO)
+
+        return args
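+
+# Example invocations (paths and URLs are placeholders):
+#   sudo python3 AutoBlockIPList.py -u https://example.com/list.txt -e 7 -v
+#   sudo python3 AutoBlockIPList.py -f ./list.txt --clear-db -b /volume1/backup --dry-run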
+
+def main():
+    """Main execution function"""
+    start_time = time.monotonic()
+    args = ArgumentParser.parse_args()
+
+    try:
+        # Initialize components
+        db_path = "/etc/synoautoblock.db"
+        db_manager = DatabaseManager(db_path)
+        ip_fetcher = IPFetcher()
+        ip_processor = IPProcessor()
+
+        # Create backup if requested
+        if args.backup_to:
+            BackupManager.create_backup(db_path, args.backup_to)
+
+        # Calculate expiration timestamp (0 means "never expires")
+        expire_time = (int(time.time()) + args.expire_in_day * 86400) if args.expire_in_day > 0 else 0
+
+        # Fetch and process IP lists
+        raw_ips = ip_fetcher.fetch_all(args.in_file, args.in_url, args.max_workers)
+        ip_entries, invalid_ips = ip_processor.process_ip_list(raw_ips, expire_time)
+
+        logger.info(f"Processing {len(ip_entries)} valid IPs (found {len(invalid_ips)} invalid entries)")
+
+        if invalid_ips and args.verbose:
+            logger.debug(f"Invalid IPs found: {', '.join(invalid_ips)}")
+
+        # Database operations
+        if ip_entries:
+            current_count = db_manager.fetch_one("SELECT COUNT(IP) FROM AutoBlockIP WHERE Deny = 1")[0]
+            logger.info(f"Current blocked IPs in database: {current_count}")
+
+            if args.remove_expired and not args.dry_run:
+                db_manager.execute_query(
+                    "DELETE FROM AutoBlockIP WHERE Deny = 1 AND ExpireTime > 0 AND ExpireTime < strftime('%s','now')"
+                )
+                logger.info("Expired entries removed")
+
+            if args.clear_db and not args.dry_run:
+                db_manager.execute_query("DELETE FROM AutoBlockIP WHERE Deny = 1")
+                logger.info("All deny entries cleared")
+
+            if not args.dry_run:
+                db_manager.execute_query(
+                    """REPLACE INTO AutoBlockIP
+                    (IP, IPStd, ExpireTime, Deny, RecordTime, Type, Meta)
+                    VALUES(?, ?, ?, 1, strftime('%s','now'), 0, NULL)""",
+                    [(entry.ip, entry.ip_std, entry.expire) for entry in ip_entries],
+                    many=True
+                )
+                new_count = db_manager.fetch_one("SELECT COUNT(IP) FROM AutoBlockIP WHERE Deny = 1")[0]
+                logger.info(f"Database updated. New blocked IP count: {new_count} ({new_count - current_count:+d} added)")
+            else:
+                logger.info("Dry run complete - no changes made to database")
+        else:
+            logger.info("No valid IPs found in the given lists")
+
+        elapsed = time.monotonic() - start_time
+        logger.info(f"Operation completed in {elapsed:.2f} seconds")
+
+    except Exception as e:
+        logger.error(f"Fatal error: {e}")
+        sys.exit(1)
 
 if __name__ == '__main__':
-    start_time = time.time()
-    args = parse_args()
-
-    # define the path of the database
-    # DSM 6: "/etc/synoautoblock.db"
-    # DSM 7: should be the same (TODO confirm path)
-    db = "/etc/synoautoblock.db"
-
-    if not os.path.isfile(db):
-        raise FileNotFoundError(f"No such file or directory: '{db}'")
-    if not os.access(db, os.R_OK):
-        raise FileExistsError("Unable to read database. Run this script with sudo or root user.")
-
-    if args.backup_to is not None:
-        filename = datetime.now().strftime("%Y%d%m_%H%M%S") + "_backup_synoautoblock.db"
-        shutil.copy(db, os.path.join(args.backup_to, filename))
-        verbose("Database successfully backup")
-
-    if args.expire_in_day > 0:
-        args.expire_in_day = int(start_time) + args.expire_in_day * 60 * 60 * 24
-
-    ips = get_ip_list(args.in_file, args.in_url)
-    ips_formatted, ips_invalid = process_ip(ips, args.expire_in_day)
-
-    verbose(f"Total IP fetch in lists: {len(ips_formatted)}")
-
-    if len(ips_formatted) > 0:
-        conn = create_connection(db)
-        c = conn.cursor()
-
-        if args.remove_expired and not args.dry_run:
-            c.execute("DELETE FROM AutoBlockIP WHERE Deny = 1 AND ExpireTime > 0 AND ExpireTime < strftime('%s','now')")
-            verbose("All expired entry was successfully removed")
-
-        if args.clear_db and not args.dry_run:
-            c.execute("DELETE FROM AutoBlockIP WHERE Deny = 1")
-            verbose("All deny entry was successfully removed")
-
-        nb_ip = c.execute("SELECT COUNT(IP) FROM AutoBlockIP WHERE Deny = 1")
-        nb_ip_before = nb_ip.fetchone()[0]
-
-        verbose(f"Total deny IP currently in your Synology DB: {nb_ip_before}")
-        if not args.dry_run:
-            c.executemany("REPLACE INTO AutoBlockIP (IP, IPStd, ExpireTime, Deny, RecordTime, Type, Meta) "
-                          "VALUES(?, ?, ?, 1, strftime('%s','now'), 0, NULL);", ips_formatted)
-        else:
-            verbose("Dry run -> nothing to do")
-        nb_ip = c.execute("SELECT COUNT(IP) FROM AutoBlockIP WHERE Deny = 1")
-        nb_ip_after = nb_ip.fetchone()[0]
-        conn.commit()
-        conn.close()
-        verbose(f"Total deny IP now in your Synology DB: {nb_ip_after} ({nb_ip_after - nb_ip_before} added)")
-    else:
-        verbose("No IP found in list")
-
-    elapsed = round(time.time() - start_time, 2)
-    verbose(f"Elapsed time: {elapsed} seconds")
+    main()
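+
+# Schema assumption: REPLACE INTO targets the stock AutoBlockIP table of DSM's
+# /etc/synoautoblock.db (IP, IPStd, ExpireTime, Deny, RecordTime, Type, Meta),
+# where ExpireTime = 0 means the entry never expires.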