diff --git a/httpz_scanner/__init__.py b/httpz_scanner/__init__.py index 25900a5..5f2529d 100644 --- a/httpz_scanner/__init__.py +++ b/httpz_scanner/__init__.py @@ -6,4 +6,4 @@ from .colors import Colors from .scanner import HTTPZScanner -__version__ = '2.1.0' \ No newline at end of file +__version__ = '2.1.1' \ No newline at end of file diff --git a/httpz_scanner/scanner.py b/httpz_scanner/scanner.py index 27b0453..5343194 100644 --- a/httpz_scanner/scanner.py +++ b/httpz_scanner/scanner.py @@ -198,72 +198,110 @@ class HTTPZScanner: async with aiohttp.ClientSession(connector=aiohttp.TCPConnector(ssl=False)) as session: tasks = set() - count = 0 # Move counter here since that's all process_result was doing + domain_queue = asyncio.Queue() + queue_empty = False - # Handle different input types - if isinstance(input_source, str): - # File or stdin input - gen = input_generator(input_source, self.shard) - async for domain in gen: - if len(tasks) >= self.concurrent_limit: - done, tasks = await asyncio.wait( - tasks, return_when=asyncio.FIRST_COMPLETED - ) - for task in done: - if result := await task: - if self.show_progress: - count += 1 - yield result - - task = asyncio.create_task(self.check_domain(session, domain)) - tasks.add(task) - - # List/tuple input - elif isinstance(input_source, (list, tuple)): - for line_num, domain in enumerate(input_source): - if domain := str(domain).strip(): - if self.shard is None or line_num % self.shard[1] == self.shard[0]: - if len(tasks) >= self.concurrent_limit: - done, tasks = await asyncio.wait( - tasks, return_when=asyncio.FIRST_COMPLETED - ) - for task in done: - if result := await task: - if self.show_progress: - count += 1 - yield result - - task = asyncio.create_task(self.check_domain(session, domain)) - tasks.add(task) - else: - # Async generator input - line_num = 0 - async for domain in input_source: - if isinstance(domain, bytes): - domain = domain.decode() - domain = domain.strip() - - if domain: - if self.shard is None or line_num % self.shard[1] == self.shard[0]: - if len(tasks) >= self.concurrent_limit: - done, tasks = await asyncio.wait( - tasks, return_when=asyncio.FIRST_COMPLETED - ) - for task in done: - if result := await task: - if self.show_progress: - count += 1 - yield result - - task = asyncio.create_task(self.check_domain(session, domain)) - tasks.add(task) - line_num += 1 - - # Process remaining tasks - if tasks: - done, _ = await asyncio.wait(tasks) - for task in done: - if result := await task: + async def process_domain(domain): + try: + result = await self.check_domain(session, domain) + if result: if self.show_progress: - count += 1 - yield result \ No newline at end of file + self.progress_count += 1 + return result + except Exception as e: + debug(f'Error processing {domain}: {str(e)}') + return None + + # Add domains to queue based on input type + async def queue_domains(): + try: + if isinstance(input_source, str): + # File or stdin input + gen = input_generator(input_source, self.shard) + async for domain in gen: + await domain_queue.put(domain) + + elif isinstance(input_source, (list, tuple)): + # List/tuple input + for line_num, domain in enumerate(input_source): + if domain := str(domain).strip(): + if self.shard is None or line_num % self.shard[1] == self.shard[0]: + await domain_queue.put(domain) + + else: + # Async generator input + line_num = 0 + async for domain in input_source: + if isinstance(domain, bytes): + domain = domain.decode() + if domain := domain.strip(): + if self.shard is None or line_num % self.shard[1] == self.shard[0]: + await domain_queue.put(domain) + line_num += 1 + except Exception as e: + debug(f'Error queuing domains: {str(e)}') + finally: + # Signal queue completion + await domain_queue.put(None) + + # Start domain queuing task + queue_task = asyncio.create_task(queue_domains()) + + try: + while not queue_empty or tasks: + # Fill up tasks to concurrent_limit + while len(tasks) < self.concurrent_limit and not queue_empty: + try: + domain = await domain_queue.get() + if domain is None: # Queue is empty + queue_empty = True + break + task = asyncio.create_task(process_domain(domain)) + tasks.add(task) + except asyncio.CancelledError: + break + except Exception as e: + debug(f'Error creating task: {str(e)}') + + if not tasks: + break + + # Wait for any task to complete with timeout + try: + done, pending = await asyncio.wait( + tasks, + timeout=self.timeout, + return_when=asyncio.FIRST_COMPLETED + ) + + # Handle completed tasks + for task in done: + tasks.remove(task) + try: + if result := await task: + yield result + except Exception as e: + debug(f'Error processing task result: {str(e)}') + + # Handle timed out tasks + if not done and pending: + for task in pending: + task.cancel() + try: + await task + except asyncio.CancelledError: + pass + tasks.remove(task) + + except Exception as e: + debug(f'Error in task processing loop: {str(e)}') + + finally: + # Clean up + for task in tasks: + task.cancel() + queue_task.cancel() + try: + await queue_task + except asyncio.CancelledError: + pass \ No newline at end of file diff --git a/setup.py b/setup.py index 8c1954e..9e133fa 100644 --- a/setup.py +++ b/setup.py @@ -10,7 +10,7 @@ with open('README.md', 'r', encoding='utf-8') as f: setup( name='httpz_scanner', - version='2.1.0', + version='2.1.1', author='acidvegas', author_email='acid.vegas@acid.vegas', description='Hyper-fast HTTP Scraping Tool', diff --git a/unit_test.py b/unit_test.py index 68ef402..b83b304 100644 --- a/unit_test.py +++ b/unit_test.py @@ -78,7 +78,7 @@ async def test_list_input(domains): ''' logging.info(f'{Colors.BOLD}Testing list input...{Colors.RESET}') - scanner = HTTPZScanner(concurrent_limit=20, timeout=3, show_progress=True, debug_mode=True) + scanner = HTTPZScanner(concurrent_limit=100, timeout=3, show_progress=True, debug_mode=True) count = 0 async for result in scanner.scan(domains): @@ -96,7 +96,7 @@ async def test_generator_input(domains): ''' logging.info(f'{Colors.BOLD}Testing generator input...{Colors.RESET}') - scanner = HTTPZScanner(concurrent_limit=20, timeout=3, show_progress=True, debug_mode=True) + scanner = HTTPZScanner(concurrent_limit=100, timeout=3, show_progress=True, debug_mode=True) count = 0 async for result in scanner.scan(domain_generator(domains)): @@ -115,8 +115,8 @@ async def main() -> None: logging.info(f'Loaded {Colors.YELLOW}{len(domains)}{Colors.RESET} domains for testing') # Run tests - await test_list_input(domains) await test_generator_input(domains) + await test_list_input(domains) logging.info(f'{Colors.GREEN}All tests completed successfully!{Colors.RESET}')