diff --git a/py_hamt/store.py b/py_hamt/store.py index a35efb5..d69123d 100644 --- a/py_hamt/store.py +++ b/py_hamt/store.py @@ -120,6 +120,9 @@ class KuboCAS(ContentAddressedStore): the internally-created session. - **rpc_base_url / gateway_base_url** (str | None): override daemon endpoints (defaults match the local daemon ports). + - **gateway_base_urls** (list[str] | None): optional list of additional + gateway URLs to try in parallel when loading blocks. Each base URL is + normalized to end with ``/ipfs/``. ... """ @@ -137,6 +140,7 @@ def __init__( session: aiohttp.ClientSession | None = None, rpc_base_url: str | None = KUBO_DEFAULT_LOCAL_RPC_BASE_URL, gateway_base_url: str | None = KUBO_DEFAULT_LOCAL_GATEWAY_BASE_URL, + gateway_base_urls: list[str] | None = None, concurrency: int = 32, *, headers: dict[str, str] | None = None, @@ -188,7 +192,14 @@ def __init__( self.rpc_url: str = f"{rpc_base_url}/api/v0/add?hash={self.hasher}&pin=false" """@private""" - self.gateway_base_url: str = f"{gateway_base_url}/ipfs/" + + def _normalize(url: str) -> str: + """Ensure URL ends with '/ipfs/'.""" + return url.rstrip("/") + "/ipfs/" + + bases = gateway_base_urls if gateway_base_urls else [gateway_base_url] + self.gateway_base_urls = [_normalize(u) for u in bases] + self.gateway_base_url = self.gateway_base_urls[0] """@private""" self._session_per_loop: dict[ @@ -262,8 +273,31 @@ async def save(self, data: bytes, codec: ContentAddressedStore.CodecInput) -> CI async def load(self, id: IPLDKind) -> bytes: """@private""" cid = cast(CID, id) # CID is definitely in the IPLDKind type - url: str = self.gateway_base_url + str(cid) - async with self._sem: # throttle gateway - async with self._loop_session().get(url) as resp: - resp.raise_for_status() - return await resp.read() + + async def _fetch(base: str) -> bytes: + url: str = base + str(cid) + async with self._sem: + async with self._loop_session().get(url) as resp: + resp.raise_for_status() + return await resp.read() + + if len(self.gateway_base_urls) == 1: + return await _fetch(self.gateway_base_urls[0]) + + tasks = [asyncio.create_task(_fetch(base)) for base in self.gateway_base_urls] + try: + for coro in asyncio.as_completed(tasks): + try: + result = await coro + except Exception: # keep racing + continue + else: + for t in tasks: + if not t.done(): + t.cancel() + return result + finally: + for t in tasks: + if not t.done(): + t.cancel() + raise RuntimeError("All gateway requests failed") diff --git a/tests/test_kubo_cas.py b/tests/test_kubo_cas.py index 54825b7..3b28656 100644 --- a/tests/test_kubo_cas.py +++ b/tests/test_kubo_cas.py @@ -1,10 +1,14 @@ -from typing import Literal, cast +import asyncio +from contextlib import asynccontextmanager +from typing import AsyncIterator, Awaitable, Callable, Literal, cast import aiohttp import dag_cbor import pytest +from aiohttp import web from dag_cbor import IPLDKind from hypothesis import given, settings +from multiformats import CID from testing_utils import ipld_strategy # noqa from py_hamt import KuboCAS @@ -144,3 +148,66 @@ async def test_kubo_cas(create_ipfs, data: IPLDKind): # noqa cid = await kubo_cas.save(dag_cbor.encode(data), codec=codec_typed) result = dag_cbor.decode(await kubo_cas.load(cid)) assert data == result + + +@pytest.mark.ipfs +@pytest.mark.asyncio(loop_scope="session") +async def test_kubo_multi_gateway(create_ipfs, global_client_session): + """Verify that multiple gateway URLs work.""" + rpc_url, gateway_url = create_ipfs + + async with KuboCAS( + rpc_base_url=rpc_url, + gateway_base_url=gateway_url, + gateway_base_urls=[gateway_url, gateway_url], + session=global_client_session, + ) as kubo_cas: + cid = await kubo_cas.save(b"hello", codec="raw") + result = await kubo_cas.load(cid) + assert result == b"hello" + + +@asynccontextmanager +async def _run_server( + handler: Callable[[web.Request], Awaitable[web.StreamResponse]], +) -> AsyncIterator[str]: + app = web.Application() + app.router.add_get("/{tail:.*}", handler) + runner = web.AppRunner(app) + await runner.setup() + site = web.TCPSite(runner, "127.0.0.1", 0) + await site.start() + port = site._server.sockets[0].getsockname()[1] + try: + yield f"http://127.0.0.1:{port}" + finally: + await runner.cleanup() + + +@pytest.mark.asyncio +async def test_gateway_race_has_fallback(): + async def fail(request: web.Request) -> web.Response: + raise web.HTTPInternalServerError() + + async def ok(request: web.Request) -> web.Response: + await asyncio.sleep(0.05) + return web.Response(body=b"ok") + + async with _run_server(fail) as bad, _run_server(ok) as good: + cid = CID.decode("bafkreihdwdcefgh4dqkjv67uzcmw7ojee6xedzdetojuzjevtenxquvyku") + async with KuboCAS(gateway_base_url=good, gateway_base_urls=[bad, good]) as cas: + assert await cas.load(cid) == b"ok" + + +@pytest.mark.asyncio +async def test_gateway_race_all_fail(): + async def fail(request: web.Request) -> web.Response: + raise web.HTTPInternalServerError() + + async with _run_server(fail) as bad1, _run_server(fail) as bad2: + cid = CID.decode("bafkreihdwdcefgh4dqkjv67uzcmw7ojee6xedzdetojuzjevtenxquvyku") + async with KuboCAS( + gateway_base_url=bad1, gateway_base_urls=[bad1, bad2] + ) as cas: + with pytest.raises(RuntimeError, match="All gateway requests failed"): + await cas.load(cid)