def _fetch_remaining_pages(connection, query_msg, timeout, fail_on_error=True):
    """
    Execute a query and merge the rows of every result page into one result.

    ``query_msg`` is sent through ``connection.wait_for_response``. While the
    returned result carries a ``paging_state``, the same message is re-sent
    with that state, and the ``parsed_rows`` of each page are appended, in
    order, to a single list that is finally stored on the result object of
    the last page.

    ``query_msg.paging_state`` is only modified for the duration of the call:
    the value it held before the first follow-up request is put back before
    this function returns or raises, so the caller can re-send the very same
    message and start again from the first page.

    :param connection: The connection to use for querying
    :param query_msg: The QueryMessage to execute (must have fetch_size set for paging)
    :param timeout: Timeout passed unchanged to every single request; it is not a
                    budget for the whole multi-page fetch
    :param fail_on_error: Passed through to ``connection.wait_for_response``.
                          If True (the default), the bare result is returned and
                          errors are left to ``wait_for_response`` to raise.
                          If False, ``(success, result)`` tuples are expected and returned.
    :return: If fail_on_error=True, the result of the last page with ``parsed_rows``
             replaced by the rows of all pages.
             If fail_on_error=False, ``(True, result)`` with the same combined result, or
             the unsuccessful tuple exactly as received as soon as any request fails.
    """
    paging_state_changed = False
    saved_paging_state = None

    try:
        # First page.
        response = connection.wait_for_response(query_msg, timeout=timeout, fail_on_error=fail_on_error)

        if fail_on_error:
            result = response
        else:
            success, result = response
            if not success:
                return response  # unsuccessful tuple, handed back untouched

        # Nothing to merge: hand back exactly what wait_for_response produced
        # (the bare result, or the (success, result) tuple).
        if not result or not result.paging_state:
            return response

        all_rows = list(result.parsed_rows) if result.parsed_rows else []

        # Follow-up pages.
        while result and result.paging_state:
            if not paging_state_changed:
                # Remember the caller's value once, right before the first overwrite.
                saved_paging_state = getattr(query_msg, 'paging_state', None)
                paging_state_changed = True
            query_msg.paging_state = result.paging_state

            page_response = connection.wait_for_response(query_msg, timeout=timeout, fail_on_error=fail_on_error)

            if fail_on_error:
                result = page_response
            else:
                page_success, page_result = page_response
                if not page_success:
                    return page_response  # unsuccessful tuple, handed back untouched
                result = page_result

            if result and result.parsed_rows:
                all_rows.extend(result.parsed_rows)

        # The last page becomes the carrier of the combined rows.
        if result:
            result.parsed_rows = all_rows

        return result if fail_on_error else (True, result)
    finally:
        # Undo the temporary mutation of the caller's message, on success,
        # on an early error return and when wait_for_response raises.
        if paging_state_changed:
            query_msg.paging_state = saved_paging_state
row, so it will never have additional pages) + # Use fail_on_error=False to match original behavior + local_success, local_result = _fetch_remaining_pages(connection, local_query, self._timeout, fail_on_error=False) + + if not local_success: + raise local_result shared_results = (peers_result, local_result) self._refresh_node_list_and_token_map(connection, preloaded_results=shared_results) @@ -3797,11 +3856,17 @@ def _refresh_node_list_and_token_map(self, connection, preloaded_results=None, log.debug("[control connection] Refreshing node list and token map") sel_local = self._SELECT_LOCAL peers_query = QueryMessage(query=maybe_add_timeout_to_query(sel_peers, self._metadata_request_timeout), - consistency_level=cl) + consistency_level=cl, + fetch_size=self._schema_meta_page_size) local_query = QueryMessage(query=maybe_add_timeout_to_query(sel_local, self._metadata_request_timeout), - consistency_level=cl) - peers_result, local_result = connection.wait_for_responses( - peers_query, local_query, timeout=self._timeout) + consistency_level=cl, + fetch_size=self._schema_meta_page_size) + + # Fetch all pages for both queries + # Note: system.local always has exactly 1 row, so it will never have additional pages + # system.peers might have multiple pages for very large clusters (>1000 nodes) + peers_result = _fetch_remaining_pages(connection, peers_query, self._timeout) + local_result = _fetch_remaining_pages(connection, local_query, self._timeout) peers_result = dict_factory(peers_result.column_names, peers_result.parsed_rows) @@ -3856,9 +3921,11 @@ def _refresh_node_list_and_token_map(self, connection, preloaded_results=None, # in system.local. See CASSANDRA-9436. 
local_rpc_address_query = QueryMessage( query=maybe_add_timeout_to_query(self._SELECT_LOCAL_NO_TOKENS_RPC_ADDRESS, self._metadata_request_timeout), - consistency_level=ConsistencyLevel.ONE) - success, local_rpc_address_result = connection.wait_for_response( - local_rpc_address_query, timeout=self._timeout, fail_on_error=False) + consistency_level=ConsistencyLevel.ONE, + fetch_size=self._schema_meta_page_size) + # Fetch all pages (system.local table always contains exactly one row, so this is effectively a no-op) + success, local_rpc_address_result = _fetch_remaining_pages( + connection, local_rpc_address_query, self._timeout, fail_on_error=False) if success: row = dict_factory( local_rpc_address_result.column_names, @@ -4092,13 +4159,19 @@ def wait_for_schema_agreement(self, connection=None, preloaded_results=None, wai while elapsed < total_timeout: peers_query = QueryMessage(query=maybe_add_timeout_to_query(select_peers_query, self._metadata_request_timeout), - consistency_level=cl) + consistency_level=cl, + fetch_size=self._schema_meta_page_size) local_query = QueryMessage(query=maybe_add_timeout_to_query(self._SELECT_SCHEMA_LOCAL, self._metadata_request_timeout), - consistency_level=cl) + consistency_level=cl, + fetch_size=self._schema_meta_page_size) try: timeout = min(self._timeout, total_timeout - elapsed) - peers_result, local_result = connection.wait_for_responses( - peers_query, local_query, timeout=timeout) + + # Fetch all pages if there are more results + # Note: system.local always has exactly 1 row, so it will never have additional pages + # system.peers might have multiple pages for very large clusters (>1000 nodes) + peers_result = _fetch_remaining_pages(connection, peers_query, timeout) + local_result = _fetch_remaining_pages(connection, local_query, timeout) except OperationTimedOut as timeout: log.debug("[control connection] Timed out waiting for " "response during schema agreement check: %s", timeout) diff --git 
a/tests/unit/test_control_connection.py b/tests/unit/test_control_connection.py index d759e12332..0061d76317 100644 --- a/tests/unit/test_control_connection.py +++ b/tests/unit/test_control_connection.py @@ -15,10 +15,10 @@ import unittest from concurrent.futures import ThreadPoolExecutor -from unittest.mock import Mock, ANY, call +from unittest.mock import Mock, ANY, call, MagicMock -from cassandra import OperationTimedOut, SchemaTargetType, SchemaChangeType -from cassandra.protocol import ResultMessage, RESULT_KIND_ROWS +from cassandra import OperationTimedOut, SchemaTargetType, SchemaChangeType, ConsistencyLevel +from cassandra.protocol import ResultMessage, RESULT_KIND_ROWS, QueryMessage from cassandra.cluster import ControlConnection, _Scheduler, ProfileManager, EXEC_PROFILE_DEFAULT, ExecutionProfile from cassandra.pool import Host from cassandra.connection import EndPoint, DefaultEndPoint, DefaultEndPointFactory @@ -167,6 +167,20 @@ def __init__(self): ["192.168.1.2", 9042, "10.0.0.2", 7040, "a", "dc1", "rack1", ["2", "102", "202"], "uuid3"]] ] self.wait_for_responses = Mock(return_value=_node_meta_results(self.local_results, self.peer_results)) + # Set up wait_for_response to return the appropriate result based on the query + def wait_for_response_side_effect(query_msg, timeout=None, fail_on_error=True): + # Create a result that matches the expected format + result = ResultMessage(kind=RESULT_KIND_ROWS) + # Return peer or local results based on some simple heuristic + if "peers" in query_msg.query.lower(): + result.column_names = self.peer_results[0] + result.parsed_rows = self.peer_results[1] + else: + result.column_names = self.local_results[0] + result.parsed_rows = self.local_results[1] + result.paging_state = None + return result + self.wait_for_response = Mock(side_effect=wait_for_response_side_effect) class FakeTime(object): @@ -305,6 +319,68 @@ def test_refresh_nodes_and_tokens(self): assert self.connection.wait_for_responses.call_count == 1 + def 
def test_topology_queries_use_paging(self):
    """
    Every topology query issued while refreshing the node list must be a
    QueryMessage carrying the control connection's schema metadata page size
    as its fetch_size.
    """
    self.control_connection.refresh_node_list_and_token_map()

    # The refresh is expected to go through wait_for_response.
    assert self.connection.wait_for_response.called

    # The message under test is the first positional argument of each recorded call.
    for recorded_call in self.connection.wait_for_response.call_args_list:
        positional_args = recorded_call[0]
        query_msg = positional_args[0]
        assert isinstance(query_msg, QueryMessage)
        assert query_msg.fetch_size == self.control_connection._schema_meta_page_size


def test_topology_queries_fetch_all_pages(self):
    """
    _fetch_remaining_pages must follow the paging_state of a first page and
    return one result holding the rows of both pages.
    """
    from cassandra.cluster import _fetch_remaining_pages

    def build_page(row, paging_state):
        # One single-row peers page; paging_state=None marks the last page.
        page = ResultMessage(kind=RESULT_KIND_ROWS)
        page.column_names = ["rpc_address", "peer", "schema_version", "data_center", "rack", "tokens", "host_id"]
        page.parsed_rows = [row]
        page.paging_state = paging_state
        return page

    # Stand-in connection whose wait_for_response hands out the two pages in turn.
    mock_connection = MagicMock()
    mock_connection.endpoint = DefaultEndPoint("192.168.1.0")
    mock_connection.original_endpoint = mock_connection.endpoint

    first_page = build_page(
        ["192.168.1.1", "10.0.0.1", "a", "dc1", "rack1", ["1", "101", "201"], "uuid2"],
        b"has_more_pages")
    second_page = build_page(
        ["192.168.1.2", "10.0.0.2", "a", "dc1", "rack1", ["2", "102", "202"], "uuid3"],
        None)
    mock_connection.wait_for_response.side_effect = [first_page, second_page]

    self.control_connection._connection = mock_connection
    query_msg = QueryMessage(query="SELECT * FROM system.peers",
                             consistency_level=ConsistencyLevel.ONE,
                             fetch_size=self.control_connection._schema_meta_page_size)

    result = _fetch_remaining_pages(mock_connection, query_msg, timeout=5)

    # Both pages contributed exactly one row each, in page order.
    assert [row[0] for row in result.parsed_rows] == ["192.168.1.1", "192.168.1.2"]
    assert result.paging_state is None

    # One request per page.
    assert mock_connection.wait_for_response.call_count == 2