From f48928ae4158409064ca86d09d84e7468c965a57 Mon Sep 17 00:00:00 2001 From: Soyoung Kim Date: Tue, 12 Aug 2025 18:00:25 +0200 Subject: [PATCH 01/35] feat: add msc3911 config option --- synapse/config/experimental.py | 8 +++ synapse/rest/client/versions.py | 4 ++ synapse/rest/media/create_resource.py | 12 +++- synapse/rest/media/upload_resource.py | 17 ++++++ tests/rest/client/test_media.py | 86 ++++++++++++++++++++++++++- 5 files changed, 125 insertions(+), 2 deletions(-) diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index 97d4cd9d5f..6af05960f5 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -588,3 +588,11 @@ def read_config( # MSC4306: Thread Subscriptions # (and MSC4308: sliding sync extension for thread subscriptions) self.msc4306_enabled: bool = experimental.get("msc4306_enabled", False) + + # MSC3911: Linking Media to Events + self.msc3911_enabled: bool = experimental.get("msc3911_enabled", False) + + # Disable the current media create and upload endpoints + self.msc3911_unrestricted_media_upload_disabled: bool = experimental.get( + "msc3911_unrestricted_media_upload_disabled", False + ) diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py index fa39eb9e6d..49944843ac 100644 --- a/synapse/rest/client/versions.py +++ b/synapse/rest/client/versions.py @@ -177,6 +177,10 @@ async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: "uk.tcpip.msc4133": self.config.experimental.msc4133_enabled, # MSC4155: Invite filtering "org.matrix.msc4155": self.config.experimental.msc4155_enabled, + # MSC3911: Linking Media to Events + "org.matrix.msc3911": self.config.experimental.msc3911_enabled, + # MSC3911: Unrestricted Media Upload + "org.matrix.msc3911.unrestricted_media_upload_disabled": self.config.experimental.msc3911_unrestricted_media_upload_disabled, }, }, ) diff --git a/synapse/rest/media/create_resource.py b/synapse/rest/media/create_resource.py index 
e45df11c9f..e2395ee1d3 100644 --- a/synapse/rest/media/create_resource.py +++ b/synapse/rest/media/create_resource.py @@ -23,7 +23,7 @@ import re from typing import TYPE_CHECKING -from synapse.api.errors import LimitExceededError +from synapse.api.errors import Codes, LimitExceededError, SynapseError from synapse.api.ratelimiting import Ratelimiter from synapse.http.server import respond_with_json from synapse.http.servlet import RestServlet @@ -53,8 +53,18 @@ def __init__(self, hs: "HomeServer", media_repo: "MediaRepository"): clock=self.clock, cfg=hs.config.ratelimiting.rc_media_create, ) + self.msc3911_unrestricted_media_upload_disabled = ( + hs.config.experimental.msc3911_unrestricted_media_upload_disabled + ) async def on_POST(self, request: SynapseRequest) -> None: + if self.msc3911_unrestricted_media_upload_disabled: + raise SynapseError( + 403, + "Unrestricted media creation is disabled", + errcode=Codes.FORBIDDEN, + ) + requester = await self.auth.get_user_by_req(request) # If the create media requests for the user are over the limit, drop them. 
diff --git a/synapse/rest/media/upload_resource.py b/synapse/rest/media/upload_resource.py index 74d8280582..9e740f724f 100644 --- a/synapse/rest/media/upload_resource.py +++ b/synapse/rest/media/upload_resource.py @@ -53,6 +53,9 @@ def __init__(self, hs: "HomeServer", media_repo: "MediaRepository"): self._media_repository_callbacks = ( hs.get_module_api_callbacks().media_repository ) + self.msc3911_unrestricted_media_upload_disabled = ( + hs.config.experimental.msc3911_unrestricted_media_upload_disabled + ) async def _get_file_metadata( self, request: SynapseRequest, user_id: str @@ -113,6 +116,13 @@ class UploadServlet(BaseUploadServlet): PATTERNS = [re.compile("/_matrix/media/(r0|v3|v1)/upload$")] async def on_POST(self, request: SynapseRequest) -> None: + if self.msc3911_unrestricted_media_upload_disabled: + raise SynapseError( + 403, + "Unrestricted media upload is disabled", + errcode=Codes.FORBIDDEN, + ) + requester = await self.auth.get_user_by_req(request) content_length, upload_name, media_type = await self._get_file_metadata( request, requester.user.to_string() @@ -145,6 +155,13 @@ class AsyncUploadServlet(BaseUploadServlet): async def on_PUT( self, request: SynapseRequest, server_name: str, media_id: str ) -> None: + if self.msc3911_unrestricted_media_upload_disabled: + raise SynapseError( + 403, + "Unrestricted media upload is disabled", + errcode=Codes.FORBIDDEN, + ) + requester = await self.auth.get_user_by_req(request) if server_name != self.server_name: diff --git a/tests/rest/client/test_media.py b/tests/rest/client/test_media.py index 7aa1f2406c..016567a9d9 100644 --- a/tests/rest/client/test_media.py +++ b/tests/rest/client/test_media.py @@ -44,7 +44,7 @@ from twisted.web.iweb import UNKNOWN_LENGTH, IResponse from twisted.web.resource import Resource -from synapse.api.errors import HttpResponseException +from synapse.api.errors import Codes, HttpResponseException from synapse.api.ratelimiting import Ratelimiter from synapse.config.oembed import 
OEmbedEndpointConfig from synapse.http.client import MultipartResponse @@ -2967,3 +2967,87 @@ def test_over_weekly_limit(self) -> None: # This will succeed as the weekly limit has reset channel = self.upload_media(900) self.assertEqual(channel.code, 200) + + +class UnrestrictedMediaUploadTestCase(unittest.HomeserverTestCase): + """ + This test case simulates a homeserver with media create and upload endpoints are + limited when `msc3911_unrestricted_media_upload_disabled` is configured to be True. + """ + + extra_config = { + "experimental_features": {"msc3911_unrestricted_media_upload_disabled": True} + } + servlets = [ + media.register_servlets, + login.register_servlets, + admin.register_servlets, + ] + + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: + config = self.default_config() + config.update(self.extra_config) + return self.setup_test_homeserver(config=config) + + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + self.media_repo = hs.get_media_repository_resource() + self.register_user("testuser", "testpass") + self.tok = self.login("testuser", "testpass") + + def create_resource_dict(self) -> dict[str, Resource]: + resources = super().create_resource_dict() + resources["/_matrix/media"] = self.hs.get_media_repository_resource() + return resources + + def test_disable_unrestricted_media_upload_post(self) -> None: + """ + Tests that the upload servlet raises an error when unrestricted media upload is disabled. 
+ """ + channel = self.make_request( + "POST", + "/_matrix/media/v3/upload?filename=test_png_upload", + SMALL_PNG, + access_token=self.tok, + shorthand=False, + content_type=b"image/png", + custom_headers=[("Content-Length", str(67))], + ) + self.assertEqual(channel.code, 403) + self.assertEqual(channel.json_body["errcode"], Codes.FORBIDDEN) + self.assertEqual( + channel.json_body["error"], "Unrestricted media upload is disabled" + ) + + def test_disable_unrestricted_media_upload_put(self) -> None: + """ + Tests that the upload servlet raises an error when unrestricted media upload is disabled. + """ + channel = self.make_request( + "PUT", + f"/_matrix/media/v3/upload/{self.hs.hostname}/test_png_upload", + content=b"dummy file content", + content_type=b"image/png", + access_token=self.tok, + ) + self.assertEqual(channel.code, 403) + self.assertEqual(channel.json_body["errcode"], Codes.FORBIDDEN) + self.assertEqual( + channel.json_body["error"], "Unrestricted media upload is disabled" + ) + + def test_disable_unrestricted_media_upload_create(self) -> None: + """ + Tests that the create servlet raises an error when unrestricted media upload is disabled. 
+ """ + channel = self.make_request( + "POST", + f"/_matrix/media/v1/create/{self.hs.hostname}/test_png_upload", + content=b"dummy file content", + content_type=b"image/png", + access_token=self.tok, + ) + self.assertEqual(channel.code, 403) + self.assertEqual(channel.json_body["errcode"], Codes.FORBIDDEN) + self.assertEqual( + channel.json_body["error"], "Unrestricted media creation is disabled" + ) From 0120afbafd7b56e6e4c4d26250c1fe0c69798e9a Mon Sep 17 00:00:00 2001 From: Jason Little Date: Tue, 12 Aug 2025 05:39:33 -0500 Subject: [PATCH 02/35] feat: Introduce preliminary database infrastructure for storage and retrieval of media restrictions --- synapse/media/media_repository.py | 16 +- .../databases/main/media_repository.py | 138 +++++++++++- synapse/storage/schema/__init__.py | 8 +- .../delta/93/01_media_attachment_tables.sql | 9 + .../93/01_media_restrictions.sql.postgres | 7 + .../delta/93/01_media_restrictions.sql.sqlite | 9 + .../02_media_attachment_tables.sql.postgres | 7 + tests/rest/client/utils.py | 27 +++ tests/storage/test_media.py | 210 ++++++++++++++++++ 9 files changed, 424 insertions(+), 7 deletions(-) create mode 100644 synapse/storage/schema/main/delta/93/01_media_attachment_tables.sql create mode 100644 synapse/storage/schema/main/delta/93/01_media_restrictions.sql.postgres create mode 100644 synapse/storage/schema/main/delta/93/01_media_restrictions.sql.sqlite create mode 100644 synapse/storage/schema/main/delta/93/02_media_attachment_tables.sql.postgres create mode 100644 tests/storage/test_media.py diff --git a/synapse/media/media_repository.py b/synapse/media/media_repository.py index d7259176e7..b36e85b152 100644 --- a/synapse/media/media_repository.py +++ b/synapse/media/media_repository.py @@ -218,12 +218,15 @@ def mark_recently_accessed(self, server_name: Optional[str], media_id: str) -> N self.recently_accessed_locals.add(media_id) @trace - async def create_media_id(self, auth_user: UserID) -> Tuple[str, int]: + async def 
create_media_id( + self, auth_user: UserID, restricted: bool = False + ) -> Tuple[str, int]: """Create and store a media ID for a local user and return the MXC URI and its expiration. Args: auth_user: The user_id of the uploader + restricted: If this is to be considered restricted media Returns: A tuple containing the MXC URI of the stored content and the timestamp at @@ -235,6 +238,7 @@ async def create_media_id(self, auth_user: UserID) -> Tuple[str, int]: media_id=media_id, time_now_ms=now, user_id=auth_user, + restricted=restricted, ) return f"mxc://{self.server_name}/{media_id}", now + self.unused_expiration_time @@ -300,6 +304,7 @@ async def create_or_update_content( content_length: int, auth_user: UserID, media_id: Optional[str] = None, + restricted: bool = False, ) -> MXCUri: """Create or update the content of the given media ID. @@ -311,6 +316,7 @@ async def create_or_update_content( auth_user: The user_id of the uploader media_id: The media ID to update if provided, otherwise creates new media ID. 
+ restricted: Boolean for if the media is restricted per msc3911 Returns: The mxc url of the stored content @@ -373,6 +379,7 @@ async def create_or_update_content( user_id=auth_user, sha256=sha256, quarantined_by="system" if should_quarantine else None, + restricted=restricted, ) else: await self.store.update_local_media( @@ -878,6 +885,10 @@ async def _download_remote_file( quarantined_by=None, authenticated=authenticated, sha256=sha256writer.hexdigest(), + # The "pre-msc3916" method for downloading over federation, restricted + # will always be false and attachments will always be None here + restricted=False, + attachments=None, ) async def _federation_download_remote_file( @@ -1011,6 +1022,9 @@ async def _federation_download_remote_file( quarantined_by=None, authenticated=authenticated, sha256=sha256writer.hexdigest(), + # Update this when the federation responses are updated + restricted=False, + attachments=None, ) def _get_thumbnail_requirements( diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py index f726846e57..8b4be62b3c 100644 --- a/synapse/storage/databases/main/media_repository.py +++ b/synapse/storage/databases/main/media_repository.py @@ -19,8 +19,10 @@ # [This file includes modifications made by New Vector Limited] # # +import json import logging from enum import Enum +from http import HTTPStatus from typing import ( TYPE_CHECKING, Collection, @@ -35,6 +37,7 @@ import attr from synapse.api.constants import Direction +from synapse.api.errors import Codes, SynapseError from synapse.logging.opentracing import trace from synapse.media._base import ThumbnailInfo from synapse.storage._base import SQLBaseStore @@ -55,6 +58,23 @@ logger = logging.getLogger(__name__) +@attr.s(slots=True, frozen=True, auto_attribs=True, kw_only=True) +class MediaRestrictions: + """ + Per MSC3911: Media can be restricted by either 'event_id' or 'profile_user_id' but + never both. 
Having neither represents an unknown restriction exists + + This needs a validator of some sort + + Attributes: + event_id: + profile_user_id: + """ + + event_id: Optional[str] = None + profile_user_id: Optional[UserID] = None + + @attr.s(slots=True, frozen=True, auto_attribs=True) class LocalMedia: media_id: str @@ -69,6 +89,8 @@ class LocalMedia: user_id: Optional[str] authenticated: Optional[bool] sha256: Optional[str] + restricted: bool + attachments: Optional[MediaRestrictions] @attr.s(slots=True, frozen=True, auto_attribs=True) @@ -84,6 +106,8 @@ class RemoteMedia: quarantined_by: Optional[str] authenticated: Optional[bool] sha256: Optional[str] + restricted: bool + attachments: Optional[MediaRestrictions] @attr.s(slots=True, frozen=True, auto_attribs=True) @@ -247,12 +271,19 @@ async def get_local_media(self, media_id: str) -> Optional[LocalMedia]: "user_id", "authenticated", "sha256", + "restricted", ), allow_none=True, desc="get_local_media", ) if row is None: return None + restriction_info = None + if bool(row[11]): + restriction_info = await self.get_media_restrictions( + self.server_name, media_id + ) + return LocalMedia( media_id=media_id, media_type=row[0], @@ -266,6 +297,8 @@ async def get_local_media(self, media_id: str) -> Optional[LocalMedia]: user_id=row[8], authenticated=row[9], sha256=row[10], + restricted=bool(row[11]), + attachments=restriction_info, ) async def get_local_media_by_user_paginate( @@ -312,7 +345,7 @@ def get_local_media_by_user_paginate_txn( sql = """ SELECT - media_id, + lmr.media_id, media_type, media_length, upload_name, @@ -323,17 +356,24 @@ def get_local_media_by_user_paginate_txn( safe_from_quarantine, user_id, authenticated, - sha256 - FROM local_media_repository + sha256, + restricted, + ma.restrictions_json->'restrictions'->>'event_id' AS event_id, + ma.restrictions_json->'restrictions'->>'event_id' AS profile_user_id + FROM local_media_repository AS lmr + -- a LEFT JOIN allows values from the right table to be NULL if 
non-existent + LEFT JOIN media_attachments AS ma ON lmr.media_id = ma.media_id + AND ma.server_name = ? WHERE user_id = ? - ORDER BY {order_by_column} {order}, media_id ASC + ORDER BY lmr.{order_by_column} {order}, lmr.media_id ASC LIMIT ? OFFSET ? """.format( order_by_column=order_by_column, order=order, ) - args += [limit, start] + # Reset the args, instead of playing prepend and append games + args = [self.server_name, user_id, limit, start] txn.execute(sql, args) media = [ LocalMedia( @@ -349,6 +389,12 @@ def get_local_media_by_user_paginate_txn( user_id=row[9], authenticated=row[10], sha256=row[11], + restricted=bool(row[12]), + attachments=MediaRestrictions( + event_id=row[13], profile_user_id=row[14] + ) + if bool(row[12]) + else None, ) for row in txn ] @@ -451,6 +497,7 @@ async def store_local_media_id( media_id: str, time_now_ms: int, user_id: UserID, + restricted: bool = False, ) -> None: if self.hs.config.media.enable_authenticated_media: authenticated = True @@ -464,6 +511,7 @@ async def store_local_media_id( "created_ts": time_now_ms, "user_id": user_id.to_string(), "authenticated": authenticated, + "restricted": restricted, }, desc="store_local_media_id", ) @@ -480,6 +528,7 @@ async def store_local_media( url_cache: Optional[str] = None, sha256: Optional[str] = None, quarantined_by: Optional[str] = None, + restricted: bool = False, ) -> None: if self.hs.config.media.enable_authenticated_media: authenticated = True @@ -499,6 +548,7 @@ async def store_local_media( "authenticated": authenticated, "sha256": sha256, "quarantined_by": quarantined_by, + "restricted": restricted, }, desc="store_local_media", ) @@ -699,12 +749,17 @@ async def get_cached_remote_media( "quarantined_by", "authenticated", "sha256", + "restricted", ), allow_none=True, desc="get_cached_remote_media", ) if row is None: return row + restriction_info = None + if row[9] is not None and row[9] is True: + restriction_info = await self.get_media_restrictions(origin, media_id) + return 
RemoteMedia( media_origin=origin, media_id=media_id, @@ -717,6 +772,8 @@ async def get_cached_remote_media( quarantined_by=row[6], authenticated=row[7], sha256=row[8], + restricted=bool(row[9]), + attachments=restriction_info, ) async def store_cached_remote_media( @@ -729,6 +786,7 @@ async def store_cached_remote_media( upload_name: Optional[str], filesystem_id: str, sha256: Optional[str], + restricted: bool = False, ) -> None: if self.hs.config.media.enable_authenticated_media: authenticated = True @@ -748,6 +806,7 @@ async def store_cached_remote_media( "last_access_ts": time_now_ms, "authenticated": authenticated, "sha256": sha256, + "restricted": restricted, }, desc="store_cached_remote_media", ) @@ -1070,3 +1129,72 @@ def _get_media_uploaded_size_for_user_txn( "get_media_uploaded_size_for_user", _get_media_uploaded_size_for_user_txn, ) + + async def get_media_restrictions( + self, server_name: str, media_id: str + ) -> Optional[MediaRestrictions]: + """ + Retrieve the restrictions json on a given media_id. Will return None if the + media has not been attached to a restrictable reference yet. If all fields return None, + the restrictable reference is unknown. + + Currently supported are: + event_id + profile_user_id + """ + # The '->' and '->>' operators are compatible with both Postgres +9.5 and SQLite +3.38.0 + sql = """ + SELECT restrictions_json->'restrictions'->>'event_id' AS event_id, restrictions_json->'restrictions'->>'profile_user_id' AS profile_user_id + FROM media_attachments + WHERE server_name = ? AND media_id = ? + """ + args = [server_name, media_id] + row: List[Tuple[Optional[str], Optional[str]]] = await self.db_pool.execute( + "get_media_restrictions_v2", sql, *args + ) + + # There should only ever be a single row. 
This is enforced by the constraint on + # the table to only have a single row per server_name/media_id combo + if row: + event_id = row[0][0] + # Because the UserID object can be None, the 'to_string()' method may not exist + profile_user_id = UserID.from_string(row[0][1]) if row[0][1] else None + return MediaRestrictions(event_id=event_id, profile_user_id=profile_user_id) + + return None + + async def set_media_restrictions( + self, + server_name: str, + media_id: str, + media_restrictions_json: JsonDict, + ) -> None: + """ + Add the media restrictions to the database + + Args: + server_name: + media_id: + media_restrictions_json: The media restrictions as dict + + Raises: + SynapseError if the media already has restrictions on it + """ + try: + await self.db_pool.simple_insert( + "media_attachments", + { + "server_name": server_name, + "media_id": media_id, + "restrictions_json": json.dumps(media_restrictions_json), + }, + ) + except self.db_pool.engine.module.IntegrityError: + # For sqlite, a unique constraint violation is an integrity error. For + # psycopg2, a UniqueViolation is a subclass of IntegrityError, so this + # covers both. + raise SynapseError( + HTTPStatus.BAD_REQUEST, + f"This media, '{media_id}' already has restrictions set.", + errcode=Codes.INVALID_PARAM, + ) diff --git a/synapse/storage/schema/__init__.py b/synapse/storage/schema/__init__.py index 3c3b13437e..e601b53883 100644 --- a/synapse/storage/schema/__init__.py +++ b/synapse/storage/schema/__init__.py @@ -19,7 +19,7 @@ # # -SCHEMA_VERSION = 92 # remember to update the list below when updating +SCHEMA_VERSION = 93 # remember to update the list below when updating """Represents the expectations made by the codebase about the database schema This should be incremented whenever the codebase changes its requirements on the @@ -168,6 +168,12 @@ Changes in SCHEMA_VERSION = 92 - Cleaned up a trigger that was added in #18260 and then reverted. 
+ +Changes in SCHEMA_VERSION = 93 + - Introduced new column for `local_media_repository` and `remote_media_cache` to + flag media as 'restricted'. Defaults to False + - Introduced new table, `media_attachments`, to hold restriction data based on + `server_name` and `media_id`. """ diff --git a/synapse/storage/schema/main/delta/93/01_media_attachment_tables.sql b/synapse/storage/schema/main/delta/93/01_media_attachment_tables.sql new file mode 100644 index 0000000000..0ad260c455 --- /dev/null +++ b/synapse/storage/schema/main/delta/93/01_media_attachment_tables.sql @@ -0,0 +1,9 @@ +-- JSONB and the operators we use are compatible with both Postgres 9.5+ and SQLite 3.38.0+ +-- The UNIQUE constraint not only ensures there is never more than one grouping +-- of restrictions for a given server_name/media_id combo, but also acts as an index +CREATE TABLE media_attachments ( + server_name TEXT NOT NULL, + media_id TEXT NOT NULL, + restrictions_json JSONB NOT NULL, + UNIQUE (server_name, media_id) +); diff --git a/synapse/storage/schema/main/delta/93/01_media_restrictions.sql.postgres b/synapse/storage/schema/main/delta/93/01_media_restrictions.sql.postgres new file mode 100644 index 0000000000..7591c90e0f --- /dev/null +++ b/synapse/storage/schema/main/delta/93/01_media_restrictions.sql.postgres @@ -0,0 +1,7 @@ +ALTER TABLE local_media_repository +ADD COLUMN restricted BOOLEAN NOT NULL DEFAULT FALSE; + +ALTER TABLE remote_media_cache +ADD COLUMN restricted BOOLEAN NOT NULL DEFAULT FALSE; + +-- by using DEFAULT FALSE, the existing data should be backwards compatible diff --git a/synapse/storage/schema/main/delta/93/01_media_restrictions.sql.sqlite b/synapse/storage/schema/main/delta/93/01_media_restrictions.sql.sqlite new file mode 100644 index 0000000000..791daa8b9e --- /dev/null +++ b/synapse/storage/schema/main/delta/93/01_media_restrictions.sql.sqlite @@ -0,0 +1,9 @@ +ALTER TABLE local_media_repository +ADD COLUMN restricted INTEGER DEFAULT 0 CHECK (restricted
IN (0, 1)); + +ALTER TABLE remote_media_cache +ADD COLUMN restricted INTEGER DEFAULT 0 CHECK (restricted IN (0, 1)); + + -- sqlite doesn't do booleans, integers is what you get + -- by using DEFAULT 0, the existing data should be backwards compatible + -- the CHECK limits stored values to 0 or 1; it does not reject NULL (SQLite treats a NULL CHECK result as passing), so it is not equivalent to NOT NULL diff --git a/synapse/storage/schema/main/delta/93/02_media_attachment_tables.sql.postgres b/synapse/storage/schema/main/delta/93/02_media_attachment_tables.sql.postgres new file mode 100644 index 0000000000..aefd9d6a5c --- /dev/null +++ b/synapse/storage/schema/main/delta/93/02_media_attachment_tables.sql.postgres @@ -0,0 +1,7 @@ +-- Postgres supports Generalized Inverted Indexes, which work well for JSON data. +CREATE INDEX media_attachments_event_id_idx ON media_attachments USING GIN ((restrictions_json->'restrictions'->'event_id')); + +CREATE INDEX media_attachments_profile_user_id_idx ON media_attachments USING GIN ((restrictions_json->'restrictions'->'profile_user_id')); + +-- Unfortunately, SQLite does not support the same sort of index. The alternative would +-- probably be an index on a generated column produced from the JSON at insertion time.
diff --git a/tests/rest/client/utils.py b/tests/rest/client/utils.py index 280486da08..dacc5639e3 100644 --- a/tests/rest/client/utils.py +++ b/tests/rest/client/utils.py @@ -607,6 +607,33 @@ def upload_media( return channel.json_body + def create_media_id_v1( + self, + tok: str, + expect_code: int = HTTPStatus.OK, + ) -> JsonDict: + """Create the media ID that can be uploaded to later + Args: + tok: The user token to use during the upload + expect_code: The return code to expect from attempting to upload the media + """ + path = "/_matrix/media/v1/create" + channel = make_request( + self.reactor, + self.site, + "POST", + path, + access_token=tok, + ) + + assert channel.code == expect_code, "Expected: %d, got: %d, resp: %r" % ( + expect_code, + channel.code, + channel.result["body"], + ) + + return channel.json_body + def whoami( self, access_token: str, diff --git a/tests/storage/test_media.py b/tests/storage/test_media.py new file mode 100644 index 0000000000..a3eaa3ce93 --- /dev/null +++ b/tests/storage/test_media.py @@ -0,0 +1,210 @@ +import io +from typing import Dict + +from matrix_common.types.mxc_uri import MXCUri + +from twisted.test.proto_helpers import MemoryReactor +from twisted.web.resource import Resource + +from synapse.api.errors import SynapseError +from synapse.rest import admin +from synapse.rest.client import login, media +from synapse.server import HomeServer +from synapse.storage.databases.main.media_repository import MediaRestrictions +from synapse.types import JsonDict, UserID +from synapse.util import Clock +from synapse.util.stringutils import random_string + +from tests import unittest +from tests.test_utils import SMALL_PNG + + +class MediaAttachmentStorageTestCase(unittest.HomeserverTestCase): + """Test that storing and retrieving media restrictions works as expected""" + + def prepare( + self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer + ) -> None: + self.store = homeserver.get_datastores().main + self.server_name = 
self.hs.config.server.server_name + + def test_store_and_retrieve_media_restrictions_by_event_id(self) -> None: + event_id = "$random_event_id" + media_restrictions = {"restrictions": {"event_id": event_id}} + media_id = random_string(24) + self.get_success_or_raise( + self.store.set_media_restrictions( + self.server_name, media_id, media_restrictions + ) + ) + + retrieved_restrictions = self.get_success_or_raise( + self.store.get_media_restrictions(self.server_name, media_id) + ) + assert retrieved_restrictions is not None + assert retrieved_restrictions.event_id == event_id + assert retrieved_restrictions.profile_user_id is None + + def test_store_and_retrieve_media_restrictions_by_profile_user_id(self) -> None: + user_id = UserID.from_string("@frank:test") + media_restrictions = {"restrictions": {"profile_user_id": user_id.to_string()}} + media_id = random_string(24) + self.get_success_or_raise( + self.store.set_media_restrictions( + self.server_name, media_id, media_restrictions + ) + ) + + retrieved_restrictions = self.get_success_or_raise( + self.store.get_media_restrictions(self.server_name, media_id) + ) + assert retrieved_restrictions is not None + assert retrieved_restrictions.event_id is None + assert retrieved_restrictions.profile_user_id == user_id + + def test_retrieve_media_without_restrictions(self) -> None: + media_id = random_string(24) + + retrieved_restrictions = self.get_success_or_raise( + self.store.get_media_restrictions(self.server_name, media_id) + ) + assert retrieved_restrictions is None + + +class MediaPendingAttachmentTestCase(unittest.HomeserverTestCase): + servlets = [ + admin.register_servlets, + login.register_servlets, + media.register_servlets, + ] + + def default_config(self) -> JsonDict: + config = super().default_config() + return config + + def prepare( + self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer + ) -> None: + self.store = homeserver.get_datastores().main + self.server_name = 
self.hs.config.server.server_name + + self.user = self.register_user("frank", "password") + self.tok = self.login("frank", "password") + + def create_resource_dict(self) -> Dict[str, Resource]: + resources = super().create_resource_dict() + # The old endpoints are not loaded with the register_servlets above + resources["/_matrix/media"] = self.hs.get_media_repository_resource() + return resources + + def test_setting_media_restriction_twice_errors( + self, + ) -> None: + """Setting media restrictions on a single piece of media TWICE is not allowed. + Test that it errors + """ + upload_result = self.helper.upload_media(SMALL_PNG, tok=self.tok) + assert upload_result.get("content_uri") is not None + + content_uri: str = upload_result["content_uri"] + # We can split the content_uri on the last "/" and the rest is the media_id + media_id = content_uri.rsplit("/", maxsplit=1)[1] + + event_id = "$something_hashy_doesnt_matter" + media_restrictions = {"restrictions": {"event_id": event_id}} + self.get_success( + self.store.set_media_restrictions( + self.server_name, media_id, media_restrictions + ) + ) + + existing_media_restrictions = self.get_success( + self.store.get_media_restrictions( + self.server_name, + media_id, + ) + ) + assert existing_media_restrictions is not None + + self.get_failure( + self.store.set_media_restrictions( + self.server_name, media_id, media_restrictions + ), + SynapseError, + ) + + +class MediaAttachmentFlowTestCase(unittest.HomeserverTestCase): + servlets = [ + admin.register_servlets, + login.register_servlets, + media.register_servlets, + ] + + def prepare( + self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer + ) -> None: + self.store = homeserver.get_datastores().main + self.server_name = self.hs.config.server.server_name + self.media_repo = self.hs.get_media_repository() + + self.user = self.register_user("frank", "password") + self.tok = self.login("frank", "password") + + def create_resource_dict(self) -> Dict[str, 
Resource]: + resources = super().create_resource_dict() + # The old endpoints are not loaded with the register_servlets above + resources["/_matrix/media"] = self.hs.get_media_repository_resource() + return resources + + def create_media(self) -> MXCUri: + content = io.BytesIO(SMALL_PNG) + content_uri = self.get_success( + self.media_repo.create_or_update_content( + "image/png", + "test_png_upload", + content, + 67, + UserID.from_string("@user_id:whatever.org"), + restricted=True, + ) + ) + return content_uri + + def test_flow(self) -> None: + """Example flow of storing media data and retrieving it from the database""" + # Create media by using create_or_update_content() helper. This will likely be + # on the new `/create` and `/upload` endpoints for msc3911. + + # set actual restrictions using storage method `set_media_restrictions()` + + # use `get_local_media()` to retrieve the data + + mxc_uri = self.create_media() + media_id = mxc_uri.media_id + assert media_id + + local_media_object = self.get_success(self.store.get_local_media(media_id)) + assert local_media_object + assert local_media_object.restricted is True + + # This one is why we are here, it doesn't exist yet + assert local_media_object.attachments is None + + event_id = "$event_id_hash_goes_here" + self.get_success( + self.store.set_media_restrictions( + self.server_name, + media_id, + {"restrictions": {"event_id": event_id}}, + ) + ) + + # Retrieve the data and make sure the restrictions are there + local_media_object = self.get_success(self.store.get_local_media(media_id)) + assert local_media_object + + assert local_media_object.restricted is True + # This one is why we are here, it's here this time. Yay! 
+ assert isinstance(local_media_object.attachments, MediaRestrictions) + assert local_media_object.attachments.event_id == event_id From c5e1337a4c16e3f6ecacf542da29551d29218804 Mon Sep 17 00:00:00 2001 From: Soyoung Kim Date: Tue, 12 Aug 2025 18:00:25 +0200 Subject: [PATCH 03/35] feat: add new media create and upload endpoints --- synapse/media/media_repository.py | 17 ++- synapse/rest/client/media.py | 5 + synapse/rest/media/create_resource.py | 19 ++- synapse/rest/media/upload_resource.py | 47 ++++++- tests/rest/client/test_media.py | 183 ++++++++++++++++++++++---- 5 files changed, 228 insertions(+), 43 deletions(-) diff --git a/synapse/media/media_repository.py b/synapse/media/media_repository.py index b36e85b152..c2e04ea767 100644 --- a/synapse/media/media_repository.py +++ b/synapse/media/media_repository.py @@ -29,7 +29,6 @@ import attr from matrix_common.types.mxc_uri import MXCUri -import twisted.internet.error import twisted.web.http from twisted.internet.defer import Deferred @@ -69,7 +68,7 @@ from synapse.media.url_previewer import UrlPreviewer from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage.databases.main.media_repository import LocalMedia, RemoteMedia -from synapse.types import UserID +from synapse.types import Requester, UserID from synapse.util.async_helpers import Linearizer from synapse.util.retryutils import NotRetryingDestination from synapse.util.stringutils import random_string @@ -494,6 +493,20 @@ async def get_local_media( if media_info.authenticated: raise NotFoundError() + # MSC3911: If media is restricted but restriction is empty, the media is in + # pending state and only creator can see it until it is attached to an event. 
+ if media_info.restricted: + restrictions = await self.store.get_media_restrictions( + self.server_name, media_info.media_id + ) + if not restrictions: + if not ( + isinstance(request.requester, Requester) + and request.requester.user.to_string() == media_info.user_id + ): + respond_404(request) + return + self.mark_recently_accessed(None, media_id) # Once we've checked auth we can return early if the media is cached on diff --git a/synapse/rest/client/media.py b/synapse/rest/client/media.py index 4c044ae900..5ff0178416 100644 --- a/synapse/rest/client/media.py +++ b/synapse/rest/client/media.py @@ -41,6 +41,8 @@ from synapse.media.media_repository import MediaRepository from synapse.media.media_storage import MediaStorage from synapse.media.thumbnailer import ThumbnailProvider +from synapse.rest.media.create_resource import CreateResource +from synapse.rest.media.upload_resource import UploadRestrictedResource from synapse.server import HomeServer from synapse.util.stringutils import parse_and_validate_server_name @@ -284,3 +286,6 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: MediaConfigResource(hs).register(http_server) ThumbnailResource(hs, media_repo, media_repo.media_storage).register(http_server) DownloadResource(hs, media_repo).register(http_server) + if hs.config.experimental.msc3911_enabled: + CreateResource(hs, media_repo, restricted=True).register(http_server) + UploadRestrictedResource(hs, media_repo).register(http_server) diff --git a/synapse/rest/media/create_resource.py b/synapse/rest/media/create_resource.py index e2395ee1d3..e863310635 100644 --- a/synapse/rest/media/create_resource.py +++ b/synapse/rest/media/create_resource.py @@ -37,9 +37,9 @@ class CreateResource(RestServlet): - PATTERNS = [re.compile("/_matrix/media/v1/create")] - - def __init__(self, hs: "HomeServer", media_repo: "MediaRepository"): + def __init__( + self, hs: "HomeServer", media_repo: "MediaRepository", restricted: bool = False + ): 
super().__init__() self.media_repo = media_repo @@ -53,12 +53,21 @@ def __init__(self, hs: "HomeServer", media_repo: "MediaRepository"): clock=self.clock, cfg=hs.config.ratelimiting.rc_media_create, ) + # MSC3911: If this is enabled, this endpoint will not allow media creation,which is unrestricted. self.msc3911_unrestricted_media_upload_disabled = ( hs.config.experimental.msc3911_unrestricted_media_upload_disabled ) + self.restricted = restricted + + if self.restricted: + self.PATTERNS = [ + re.compile("/_matrix/client/unstable/org.matrix.msc3911/media/create") + ] + else: + self.PATTERNS = [re.compile("/_matrix/media/v1/create")] async def on_POST(self, request: SynapseRequest) -> None: - if self.msc3911_unrestricted_media_upload_disabled: + if not self.restricted and self.msc3911_unrestricted_media_upload_disabled: raise SynapseError( 403, "Unrestricted media creation is disabled", @@ -81,7 +90,7 @@ async def on_POST(self, request: SynapseRequest) -> None: ) content_uri, unused_expires_at = await self.media_repo.create_media_id( - requester.user + requester.user, restricted=self.restricted ) logger.info( diff --git a/synapse/rest/media/upload_resource.py b/synapse/rest/media/upload_resource.py index 9e740f724f..97e478cb86 100644 --- a/synapse/rest/media/upload_resource.py +++ b/synapse/rest/media/upload_resource.py @@ -53,6 +53,7 @@ def __init__(self, hs: "HomeServer", media_repo: "MediaRepository"): self._media_repository_callbacks = ( hs.get_module_api_callbacks().media_repository ) + # MSC3911: If this is enabled, this endpoint will not allow unrestricted media uploads. self.msc3911_unrestricted_media_upload_disabled = ( hs.config.experimental.msc3911_unrestricted_media_upload_disabled ) @@ -145,6 +146,45 @@ async def on_POST(self, request: SynapseRequest) -> None: ) +class UploadRestrictedResource(BaseUploadServlet): + """ + MSC3911: This is an unstable API endpoint for uploading "restricted" media. 
This is + equivalent to the existing upload endpoint `/_matrix/media/v3/upload$`, but media + is not initially viewable except to the user that uploaded it, until it is attached + to an event. + """ + + PATTERNS = [re.compile("/_matrix/client/unstable/org.matrix.msc3911/media/upload$")] + + async def on_POST(self, request: SynapseRequest) -> None: + requester = await self.auth.get_user_by_req(request) + content_length, upload_name, media_type = await self._get_file_metadata( + request, requester.user.to_string() + ) + try: + content: IO = request.content # type: ignore + content_uri = await self.media_repo.create_or_update_content( + media_type, + upload_name, + content, + content_length, + requester.user, + restricted=True, # Media is marked as restricted. + ) + except SpamMediaException: + # For uploading of media we want to respond with a 400, instead of + # the default 404, as that would just be confusing. + raise SynapseError(400, "Bad content") + except Exception as e: + logger.error("Failed to upload media: %s", e) + raise SynapseError(500, "Failed to upload media") + + logger.info("Uploaded content with URI '%s'", content_uri) + respond_with_json( + request, 200, {"content_uri": str(content_uri)}, send_cors=True + ) + + class AsyncUploadServlet(BaseUploadServlet): PATTERNS = [ re.compile( @@ -155,13 +195,6 @@ class AsyncUploadServlet(BaseUploadServlet): async def on_PUT( self, request: SynapseRequest, server_name: str, media_id: str ) -> None: - if self.msc3911_unrestricted_media_upload_disabled: - raise SynapseError( - 403, - "Unrestricted media upload is disabled", - errcode=Codes.FORBIDDEN, - ) - requester = await self.auth.get_user_by_req(request) if server_name != self.server_name: diff --git a/tests/rest/client/test_media.py b/tests/rest/client/test_media.py index 016567a9d9..ba38b56041 100644 --- a/tests/rest/client/test_media.py +++ b/tests/rest/client/test_media.py @@ -2969,7 +2969,7 @@ def test_over_weekly_limit(self) -> None: 
self.assertEqual(channel.code, 200) -class UnrestrictedMediaUploadTestCase(unittest.HomeserverTestCase): +class DisableUnrestrictedResourceTestCase(unittest.HomeserverTestCase): """ This test case simulates a homeserver with media create and upload endpoints are limited when `msc3911_unrestricted_media_upload_disabled` is configured to be True. @@ -2980,8 +2980,6 @@ class UnrestrictedMediaUploadTestCase(unittest.HomeserverTestCase): } servlets = [ media.register_servlets, - login.register_servlets, - admin.register_servlets, ] def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: @@ -2991,43 +2989,38 @@ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.media_repo = hs.get_media_repository_resource() - self.register_user("testuser", "testpass") - self.tok = self.login("testuser", "testpass") def create_resource_dict(self) -> dict[str, Resource]: resources = super().create_resource_dict() resources["/_matrix/media"] = self.hs.get_media_repository_resource() return resources - def test_disable_unrestricted_media_upload_post(self) -> None: + def test_unrestricted_resource_creation_disabled(self) -> None: """ - Tests that the upload servlet raises an error when unrestricted media upload is disabled. + Tests that CreateResource raises an error when + `msc3911_unrestricted_media_upload_disabled` is True. 
""" channel = self.make_request( "POST", - "/_matrix/media/v3/upload?filename=test_png_upload", - SMALL_PNG, - access_token=self.tok, - shorthand=False, - content_type=b"image/png", - custom_headers=[("Content-Length", str(67))], + "/_matrix/media/v1/create", ) self.assertEqual(channel.code, 403) self.assertEqual(channel.json_body["errcode"], Codes.FORBIDDEN) self.assertEqual( - channel.json_body["error"], "Unrestricted media upload is disabled" + channel.json_body["error"], "Unrestricted media creation is disabled" ) - def test_disable_unrestricted_media_upload_put(self) -> None: + def test_unrestricted_resource_upload_disabled(self) -> None: """ - Tests that the upload servlet raises an error when unrestricted media upload is disabled. + Tests that UploadServlet raises an error when + `msc3911_unrestricted_media_upload_disabled` is True. """ channel = self.make_request( - "PUT", - f"/_matrix/media/v3/upload/{self.hs.hostname}/test_png_upload", - content=b"dummy file content", + "POST", + "/_matrix/media/v3/upload?filename=test_png_upload", + content=SMALL_PNG, content_type=b"image/png", - access_token=self.tok, + custom_headers=[("Content-Length", str(67))], ) self.assertEqual(channel.code, 403) self.assertEqual(channel.json_body["errcode"], Codes.FORBIDDEN) @@ -3035,19 +3028,151 @@ def test_disable_unrestricted_media_upload_put(self) -> None: channel.json_body["error"], "Unrestricted media upload is disabled" ) - def test_disable_unrestricted_media_upload_create(self) -> None: + +class RestrictedResourceTestCase(unittest.HomeserverTestCase): + """ + Tests restricted media creation and upload endpoints when `msc3911_enabled` is + configured to be True. 
+ """ + + extra_config = { + "experimental_features": {"msc3911_enabled": True}, + } + servlets = [ + media.register_servlets, + login.register_servlets, + admin.register_servlets, + ] + + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: + config = self.default_config() + config.update(self.extra_config) + return self.setup_test_homeserver(config=config) + + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + self.media_repo = hs.get_media_repository_resource() + self.register_user("creator", "testpass") + self.creator_tok = self.login("creator", "testpass") + + self.register_user("random_user", "testpass") + self.other_user_tok = self.login("random_user", "testpass") + + def create_resource_dict(self) -> dict[str, Resource]: + resources = super().create_resource_dict() + resources["/_matrix/media"] = self.hs.get_media_repository_resource() + return resources + + def test_create_restricted_resource(self) -> None: + """ + Tests that the new create endpoint creates a restricted resource. + """ + channel = self.make_request( + "POST", + "/_matrix/client/unstable/org.matrix.msc3911/media/create", + access_token=self.creator_tok, + ) + self.assertEqual(channel.code, 200) + self.assertIn("content_uri", channel.json_body) + self.assertIn("unused_expires_at", channel.json_body) + media_id = channel.json_body["content_uri"].split("/")[-1] + + # Check the `restricted` field is True. + media = self.get_success( + self.hs.get_datastores().main.get_local_media(media_id) + ) + assert media is not None + self.assertEqual(media.media_id, media_id) + self.assertTrue(media.restricted) + + def test_upload_restricted_resource(self) -> None: """ - Tests that the create servlet raises an error when unrestricted media upload is disabled. + Tests that the new upload endpoints uploads a restricted resource. 
""" channel = self.make_request( "POST", - f"/_matrix/media/v1/create/{self.hs.hostname}/test_png_upload", - content=b"dummy file content", + "/_matrix/client/unstable/org.matrix.msc3911/media/upload?filename=test_png_upload", + content=SMALL_PNG, content_type=b"image/png", - access_token=self.tok, + access_token=self.creator_tok, + custom_headers=[("Content-Length", str(67))], ) - self.assertEqual(channel.code, 403) - self.assertEqual(channel.json_body["errcode"], Codes.FORBIDDEN) - self.assertEqual( - channel.json_body["error"], "Unrestricted media creation is disabled" + self.assertEqual(channel.code, 200) + self.assertIn("content_uri", channel.json_body) + media_id = channel.json_body["content_uri"].split("/")[-1] + + # Check the `restricted` field is True. + media = self.get_success( + self.hs.get_datastores().main.get_local_media(media_id) + ) + assert media is not None + self.assertEqual(media.media_id, media_id) + self.assertTrue(media.restricted) + + # The media is not attached to any event yet, only creator can see it. + # The creator can download the restricted resource. + channel = self.make_request( + "GET", + f"/_matrix/client/v1/media/download/{self.hs.hostname}/{media_id}", + access_token=self.creator_tok, + ) + assert channel.code == 200 + + # The other user cannot download the restricted resource in pending state. + channel = self.make_request( + "GET", + f"/_matrix/client/v1/media/download/{self.hs.hostname}/{media_id}", + access_token=self.other_user_tok, + ) + assert channel.code == 404 + + def test_async_upload_restricted_resource(self) -> None: + """ + Tests the combination of new create endpoint and the existing async upload + endpoint uploads a restricted resource. + """ + # Create media with new endpoint. 
+ channel = self.make_request( + "POST", + "/_matrix/client/unstable/org.matrix.msc3911/media/create", + access_token=self.creator_tok, + ) + self.assertEqual(channel.code, 200) + self.assertIn("content_uri", channel.json_body) + self.assertIn("unused_expires_at", channel.json_body) + media_id = channel.json_body["content_uri"].split("/")[-1] + + # Async upload with existing endpoint. + channel = self.make_request( + "PUT", + f"/_matrix/media/v3/upload/{self.hs.hostname}/{media_id}", + content=SMALL_PNG, + content_type=b"image/png", + access_token=self.creator_tok, + custom_headers=[("Content-Length", str(67))], + ) + self.assertEqual(channel.code, 200) + + # Check the `restricted` field is True. + media = self.get_success( + self.hs.get_datastores().main.get_local_media(media_id) + ) + assert media is not None + self.assertEqual(media.media_id, media_id) + self.assertTrue(media.restricted) + + # Media is not attached to any event yet, only creator can see it. + # The creator can download the restricted resource. + channel = self.make_request( + "GET", + f"/_matrix/client/v1/media/download/{self.hs.hostname}/{media_id}", + access_token=self.creator_tok, + ) + assert channel.code == 200 + + # The other user cannot download the restricted resource. 
+ channel = self.make_request( + "GET", + f"/_matrix/client/v1/media/download/{self.hs.hostname}/{media_id}", + access_token=self.other_user_tok, ) + assert channel.code == 404 From 955928e69225e34e9d95f5afb0ce1b4f0590ef29 Mon Sep 17 00:00:00 2001 From: Soyoung Kim Date: Fri, 22 Aug 2025 11:51:29 +0200 Subject: [PATCH 04/35] chore: update workflow for nightly image --- .github/workflows/docker-famedly.yml | 22 +++++++++++++++++++++- 1 file changed, 21 insertions(+), 1 deletion(-) diff --git a/.github/workflows/docker-famedly.yml b/.github/workflows/docker-famedly.yml index 29e6af85f4..4ef4dd70a0 100644 --- a/.github/workflows/docker-famedly.yml +++ b/.github/workflows/docker-famedly.yml @@ -7,6 +7,13 @@ name: Docker on: push: tags: ["v*.*.*_*"] + workflow_dispatch: + # Manually trigger to build and push docker image with modules to nightly Harbor registry. + inputs: + tag: + description: 'Provide tag name with work package. Tag must only contain ASCII letters, digits, underscores, periods, or dashes.' + required: true + type: string concurrency: group: ${{ github.workflow }}-${{ github.ref }} @@ -59,6 +66,18 @@ jobs: outputs: build_matrix: ${{ steps.get-matrix.outputs.build_matrix }} + validate_image_tag: + # Validate the tag input for workflow_dispatch. + if: ${{ github.event_name == 'workflow_dispatch' }} + runs-on: ubuntu-latest + steps: + - name: Validate Image Tag + run: | + if ! [[ "${{ github.event.inputs.tag }}" =~ ^[a-zA-Z0-9._-]+$ ]]; then + echo "Error: tag must only contain ASCII letters, digits, underscores, periods, or dashes." 
+ exit 1 + fi + production-build: if: ${{ !cancelled() && !failure() }} # Allow for stopping the build job needs: @@ -70,7 +89,7 @@ jobs: with: push: ${{ github.event_name != 'pull_request' }} # Always build, don't publish on pull requests registry_user: ${{ vars.REGISTRY_USER }} - registry: registry.famedly.net/docker-oss + registry: ${{ github.event_name == 'workflow_dispatch' && 'registry.famedly.net/docker-nightly' || 'registry.famedly.net/docker-oss' }} image_name: synapse file: docker/Dockerfile-famedly # Notice that there is a leading 'sha-' in front of the actual sha, as that is @@ -82,5 +101,6 @@ jobs: # Tag the production image used for famedly deployments. tags: | type=ref,event=tag,suffix=-${{ matrix.job.mod_pack_name }} + type=raw,enable=${{ github.event_name == 'workflow_dispatch' }},value=${{ matrix.job.mod_pack_name }}-${{ github.event.inputs.tag }} flavor: latest=false secrets: inherit From 5529009d8933ee6de677a21bb828ac33821f6a7f Mon Sep 17 00:00:00 2001 From: Jason Little Date: Mon, 25 Aug 2025 12:45:49 -0500 Subject: [PATCH 05/35] feat: msc3911[AP3] - Update methods for sending events to allow attaching media --- rust/src/events/internal_metadata.rs | 20 + synapse/handlers/message.py | 89 ++++ synapse/handlers/room_member.py | 23 + synapse/rest/client/room.py | 115 ++++- synapse/synapse_rust/events.pyi | 2 + tests/rest/client/test_rooms.py | 609 +++++++++++++++++++++++++++ 6 files changed, 855 insertions(+), 3 deletions(-) diff --git a/rust/src/events/internal_metadata.rs b/rust/src/events/internal_metadata.rs index eeb6074c10..ecfda7bbc1 100644 --- a/rust/src/events/internal_metadata.rs +++ b/rust/src/events/internal_metadata.rs @@ -58,12 +58,17 @@ enum EventInternalMetadataData { TxnId(Box), TokenId(i64), DeviceId(Box), + MediaReferences(Vec), } impl EventInternalMetadataData { /// Convert the field to its name and python object. 
fn to_python_pair<'a>(&self, py: Python<'a>) -> (&'a Bound<'a, PyString>, Bound<'a, PyAny>) { match self { + EventInternalMetadataData::MediaReferences(o) => ( + pyo3::intern!(py, "media_references"), + o.into_pyobject(py).unwrap().into_any(), + ), EventInternalMetadataData::OutOfBandMembership(o) => ( pyo3::intern!(py, "out_of_band_membership"), o.into_pyobject(py) @@ -128,6 +133,11 @@ impl EventInternalMetadataData { let key_str: PyBackedStr = key.extract()?; let e = match &*key_str { + "media_references" => EventInternalMetadataData::MediaReferences( + value + .extract() + .with_context(|| format!("'{key_str}' has invalid type"))?, + ), "out_of_band_membership" => EventInternalMetadataData::OutOfBandMembership( value .extract() @@ -469,4 +479,14 @@ impl EventInternalMetadata { fn set_device_id(&mut self, obj: String) { set_property!(self, DeviceId, obj.into_boxed_str()); } + + /// The media references for the restrictions being set for this event, if any. + #[getter] + fn get_media_references(&self) -> Option<&Vec> { + get_property_opt!(self, MediaReferences) + } + #[setter] + fn set_media_references(&mut self, obj: Vec) { + set_property!(self, MediaReferences, obj); + } } diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index cb64df2d01..da1a5724de 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -25,6 +25,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Sequence, Tuple from canonicaljson import encode_canonical_json +from matrix_common.types.mxc_uri import MXCUri from twisted.internet.interfaces import IDelayedCall @@ -70,6 +71,10 @@ from synapse.replication.http.send_event import ReplicationSendEventRestServlet from synapse.replication.http.send_events import ReplicationSendEventsRestServlet from synapse.storage.databases.main.events_worker import EventRedactBehaviour +from synapse.storage.databases.main.media_repository import ( + LocalMedia, + MediaRestrictions, +) from 
synapse.types import ( JsonDict, PersistedEventPosition, @@ -583,6 +588,7 @@ async def create_event( state_map: Optional[StateMap[str]] = None, for_batch: bool = False, current_state_group: Optional[int] = None, + mxc_restriction_list_for_event: Optional[List[MXCUri]] = None, ) -> Tuple[EventBase, UnpersistedEventContextBase]: """ Given a dict from a client, create a new event. If bool for_batch is true, will @@ -637,6 +643,9 @@ async def create_event( current_state_group: the current state group, used only for creating events for batch persisting + mxc_restriction_list_for_event: An optional List of MXCUri objects, to be + used for setting media restrictions + Raises: ResourceLimitError if server is blocked to some resource being exceeded @@ -716,6 +725,11 @@ async def create_event( if txn_id is not None: builder.internal_metadata.txn_id = txn_id + if mxc_restriction_list_for_event is not None: + builder.internal_metadata.media_references = [ + str(mxc) for mxc in mxc_restriction_list_for_event + ] + builder.internal_metadata.outlier = outlier event, unpersisted_context = await self.create_new_client_event( @@ -956,6 +970,7 @@ async def create_and_send_nonmember_event( ignore_shadow_ban: bool = False, outlier: bool = False, depth: Optional[int] = None, + media_info_for_attachment: Optional[set[LocalMedia]] = None, ) -> Tuple[EventBase, int]: """ Creates an event, then sends it. @@ -984,6 +999,8 @@ async def create_and_send_nonmember_event( depth: Override the depth used to order the event in the DAG. Should normally be set to None, which will cause the depth to be calculated based on the prev_events. + media_info_for_attachment: An optional set of LocalMedia objects, for use in + restricting media. 
Returns: The event, and its stream ordering (if deduplication happened, @@ -1057,6 +1074,7 @@ async def create_and_send_nonmember_event( ignore_shadow_ban=ignore_shadow_ban, outlier=outlier, depth=depth, + media_info_for_attachment=media_info_for_attachment, ) async def _create_and_send_nonmember_event_locked( @@ -1070,6 +1088,7 @@ async def _create_and_send_nonmember_event_locked( ignore_shadow_ban: bool = False, outlier: bool = False, depth: Optional[int] = None, + media_info_for_attachment: Optional[set[LocalMedia]] = None, ) -> Tuple[EventBase, int]: room_id = event_dict["room_id"] @@ -1098,6 +1117,13 @@ async def _create_and_send_nonmember_event_locked( state_event_ids=state_event_ids, outlier=outlier, depth=depth, + mxc_restriction_list_for_event=[ + MXCUri(self.server_name, local_media.media_id) + for local_media in media_info_for_attachment + if local_media + ] + if media_info_for_attachment is not None + else None, ) context = await unpersisted_context.persist(event) @@ -1158,6 +1184,7 @@ async def _create_and_send_nonmember_event_locked( events_and_context=[(event, context)], ratelimit=ratelimit, ignore_shadow_ban=ignore_shadow_ban, + media_info_for_attachment=media_info_for_attachment, ) break @@ -1431,6 +1458,7 @@ async def handle_new_client_event( ratelimit: bool = True, extra_users: Optional[List[UserID]] = None, ignore_shadow_ban: bool = False, + media_info_for_attachment: Optional[set[LocalMedia]] = None, ) -> EventBase: """Processes new events. Please note that if batch persisting events, an error in handling any one of these events will result in all of the events being dropped. @@ -1450,6 +1478,9 @@ async def handle_new_client_event( ignore_shadow_ban: True if shadow-banned users should be allowed to send this event. + media_info_for_attachment: An optional set of LocalMedia objects, for use in + restricting media. + Return: If the event was deduplicated, the previous, duplicate, event. Otherwise, `event`. 
@@ -1460,6 +1491,16 @@ async def handle_new_client_event( a room that has been un-partial stated. """ extra_users = extra_users or [] + media_info_for_attachment = media_info_for_attachment or set() + + # filter for the existing media attachments that were passed in based on the + # mxc. The 'attachments' key can be None, representing that an attachment has + # not been formed yet. If they are all None, will be an empty set + media_restrictions: set[MediaRestrictions] = { + local_media.attachments + for local_media in media_info_for_attachment + if local_media.attachments + } for event, context in events_and_context: # we don't apply shadow-banning to membership events here. Invites are blocked @@ -1482,8 +1523,40 @@ async def handle_new_client_event( event.event_id, prev_event.event_id, ) + if media_restrictions: + # Sort out what event_id's were part of the restrictions. + existing_event_ids_from_media_restrictions = { + res.event_id for res in media_restrictions + } + + # If the de-duplicated event_id matches one of the existing + # restrictions, then all is well. If it does not, then this + # needs to be denied as invalid + if ( + prev_event.event_id + not in existing_event_ids_from_media_restrictions + ): + logger.warning( + "De-duplicated state event '%s' was not already attached to this media", + prev_event.event_id, + ) + raise SynapseError( + HTTPStatus.BAD_REQUEST, + "De-duplicated state event was not already attached to this media", + Codes.INVALID_PARAM, + ) + return prev_event + # Some media was trying to be attached to an event, but that media was + # already attached. 
Deny + if media_restrictions: + raise SynapseError( + HTTPStatus.BAD_REQUEST, + f"These media ids, '{media_info_for_attachment}' has already been attached to a reference: {media_restrictions}", + Codes.INVALID_PARAM, + ) + if not event.is_state() and event.type in [ EventTypes.Message, EventTypes.Encrypted, @@ -1552,6 +1625,7 @@ async def create_and_send_new_client_events( event_dicts: Sequence[JsonDict], ratelimit: bool = True, ignore_shadow_ban: bool = False, + media_info_for_attachment: Optional[set[LocalMedia]] = None, ) -> None: """Helper to create and send a batch of new client events. @@ -1573,6 +1647,8 @@ async def create_and_send_new_client_events( ratelimit: Whether to rate limit this send. ignore_shadow_ban: True if shadow-banned users should be allowed to send these events. + media_info_for_attachment: An optional set of LocalMedia objects, for use in + restricting media. """ if not event_dicts: @@ -1634,6 +1710,7 @@ async def create_and_send_new_client_events( events_and_context, ignore_shadow_ban=ignore_shadow_ban, ratelimit=ratelimit, + media_info_for_attachment=media_info_for_attachment, ) async def _persist_events( @@ -2056,6 +2133,18 @@ async def persist_and_notify_client_events( events_and_pos = [] for event in persisted_events: + # Access the 'media_references' object from the event internal metadata. + # This will be None if it was not attached during creation of the event. + maybe_media_restrictions_to_set = event.internal_metadata.media_references + + if maybe_media_restrictions_to_set: + for mxc_str in maybe_media_restrictions_to_set: + mxc = MXCUri.from_str(mxc_str) + await self.store.set_media_restrictions( + mxc.server_name, + mxc.media_id, + {"restrictions": {"event_id": event.event_id}}, + ) if self._ephemeral_events_enabled: # If there's an expiry timestamp on the event, schedule its expiry. 
self._message_handler.maybe_schedule_expiry(event) diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index fea25b9920..c58a327681 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -25,6 +25,8 @@ from http import HTTPStatus from typing import TYPE_CHECKING, Iterable, List, Optional, Set, Tuple +from matrix_common.types.mxc_uri import MXCUri + from synapse import types from synapse.api.constants import ( AccountDataTypes, @@ -52,6 +54,7 @@ from synapse.metrics import event_processing_positions from synapse.metrics.background_process_metrics import run_as_background_process from synapse.replication.http.push import ReplicationCopyPusherRestServlet +from synapse.storage.databases.main.media_repository import LocalMedia from synapse.storage.databases.main.state_deltas import StateDelta from synapse.storage.invite_rule import InviteRule from synapse.types import ( @@ -402,6 +405,7 @@ async def _local_membership_update( require_consent: bool = True, outlier: bool = False, origin_server_ts: Optional[int] = None, + media_info_for_attachment: Optional[set[LocalMedia]] = None, ) -> Tuple[str, int]: """ Internal membership update function to get an existing event or create @@ -434,6 +438,8 @@ async def _local_membership_update( opposed to being inline with the current DAG. origin_server_ts: The origin_server_ts to use if a new event is created. Uses the current timestamp if set to None. + media_info_for_attachment: An optional set of LocalMedia objects, for use in + restricting media. 
Returns: Tuple of event ID and stream ordering position @@ -486,6 +492,13 @@ async def _local_membership_update( depth=depth, require_consent=require_consent, outlier=outlier, + mxc_restriction_list_for_event=[ + MXCUri(self._server_name, local_media.media_id) + for local_media in media_info_for_attachment + if local_media + ] + if media_info_for_attachment is not None + else None, ) context = await unpersisted_context.persist(event) prev_state_ids = await context.get_prev_state_ids( @@ -503,6 +516,7 @@ async def _local_membership_update( events_and_context=[(event, context)], extra_users=[target], ratelimit=ratelimit, + media_info_for_attachment=media_info_for_attachment, ) ) @@ -581,6 +595,7 @@ async def update_membership( state_event_ids: Optional[List[str]] = None, depth: Optional[int] = None, origin_server_ts: Optional[int] = None, + media_info_for_attachment: Optional[set[LocalMedia]] = None, ) -> Tuple[str, int]: """Update a user's membership in a room. @@ -611,6 +626,8 @@ async def update_membership( based on the prev_events. origin_server_ts: The origin_server_ts to use if a new event is created. Uses the current timestamp if set to None. + media_info_for_attachment: An optional set of LocalMedia objects, for use in + restricting media. Returns: A tuple of the new event ID and stream ID. @@ -673,6 +690,7 @@ async def update_membership( state_event_ids=state_event_ids, depth=depth, origin_server_ts=origin_server_ts, + media_info_for_attachment=media_info_for_attachment, ) return result @@ -695,6 +713,7 @@ async def update_membership_locked( state_event_ids: Optional[List[str]] = None, depth: Optional[int] = None, origin_server_ts: Optional[int] = None, + media_info_for_attachment: Optional[set[LocalMedia]] = None, ) -> Tuple[str, int]: """Helper for update_membership. @@ -727,6 +746,8 @@ async def update_membership_locked( based on the prev_events. origin_server_ts: The origin_server_ts to use if a new event is created. 
Uses the current timestamp if set to None. + media_info_for_attachment: An optional set of LocalMedia objects, for use in + restricting media. Returns: A tuple of the new event ID and stream ID. @@ -931,6 +952,7 @@ async def update_membership_locked( require_consent=require_consent, outlier=outlier, origin_server_ts=origin_server_ts, + media_info_for_attachment=media_info_for_attachment, ) latest_event_ids = await self.store.get_prev_events_for_room(room_id) @@ -1189,6 +1211,7 @@ async def update_membership_locked( require_consent=require_consent, outlier=outlier, origin_server_ts=origin_server_ts, + media_info_for_attachment=media_info_for_attachment, ) async def check_for_any_membership_in_room( diff --git a/synapse/rest/client/room.py b/synapse/rest/client/room.py index 6b0deda0df..0e09fd36d8 100644 --- a/synapse/rest/client/room.py +++ b/synapse/rest/client/room.py @@ -25,9 +25,10 @@ import re from enum import Enum from http import HTTPStatus -from typing import TYPE_CHECKING, Awaitable, Dict, List, Optional, Tuple +from typing import TYPE_CHECKING, Awaitable, Callable, Dict, List, Optional, Tuple from urllib import parse as urlparse +from matrix_common.types.mxc_uri import MXCUri from prometheus_client.core import Histogram from twisted.web.server import Request @@ -69,6 +70,7 @@ from synapse.rest.client._base import client_patterns from synapse.rest.client.transactions import HttpTransactionCache from synapse.state import CREATE_KEY, POWER_KEY +from synapse.storage.databases.main.media_repository import LocalMedia from synapse.streams.config import PaginationConfig from synapse.types import JsonDict, Requester, StreamToken, ThirdPartyInstanceID, UserID from synapse.types.state import StateFilter @@ -191,6 +193,75 @@ def get_room_config(self, request: Request) -> JsonDict: return user_supplied_config +async def validate_attachment_request_and_retrieve_media_info( + requester: Requester, + request: SynapseRequest, + get_local_media_info_cb: Callable[[str], 
Awaitable[Optional[LocalMedia]]], +) -> set[LocalMedia]: + """ + Parse the request args for potential media attachment parameters. Validate those + parameters are safe and sane, then retrieve the media information for each. + + Args: + requester: The user placing the request, to verify they are allowed + request: The request itself that has the parameters for parsing + get_local_media_info_cb: The callback on the media repository that will retrieve + the media information + Returns: + Return the media info in a set, or an empty set if appropriate + + Raises: + SynapseError: If any of the media is inappropriate or if the requester was not + allowed to attach the media + """ + # twisted.web.server.Request.args is incorrectly defined as Optional[Any] + args: Dict[bytes, List[bytes]] = request.args # type: ignore + attach_media_list = parse_strings_from_args( + args, "org.matrix.msc3911.attach_media", default=[] + ) + + attach_media_set = set() + for unsafe_mxc in attach_media_list: + # Deduplicate multiple attach media requests with the same MXC + # It may be that the mxc provided does not have the correct scheme + # attached, coax it into the correct form + if not unsafe_mxc.startswith("mxc://"): + unsafe_mxc = f"mxc://{unsafe_mxc}" + attach_media_set.add(MXCUri.from_str(unsafe_mxc)) + + # Check each mxc passed in and raise 400, INVALID_PARAM for: + # * being non-existent + # * not being restricted + # * being from the wrong user + + # XXX: Do we need to watch for race-y conditions where the local media info + # will have changed between now and when it is persisted in a moment? + media_info_for_attachment = set() + for mxc_uri in attach_media_set: + media_info = await get_local_media_info_cb( + mxc_uri.media_id, + ) + + # Do not expose that it was a different user that uploaded the media. 
Denial of + # metadata leak + if media_info is None or media_info.user_id != requester.user.to_string(): + raise SynapseError( + HTTPStatus.BAD_REQUEST, + f"The media attachment request is invalid as the media '{mxc_uri.media_id}' does not exist", + Codes.INVALID_PARAM, + ) + + if not media_info.restricted: + raise SynapseError( + HTTPStatus.BAD_REQUEST, + f"The media attachment request is invalid as the media '{mxc_uri.media_id}' is not restricted", + Codes.INVALID_PARAM, + ) + + media_info_for_attachment.add(media_info) + return media_info_for_attachment + + # TODO: Needs unit testing for generic events class RoomStateEventRestServlet(RestServlet): CATEGORY = "Event sending requests" @@ -205,6 +276,9 @@ def __init__(self, hs: "HomeServer"): self.clock = hs.get_clock() self._max_event_delay_ms = hs.config.server.max_event_delay_ms self._spam_checker_module_callbacks = hs.get_module_api_callbacks().spam_checker + self.store = hs.get_datastores().main + self.enable_restricted_media = hs.config.experimental.msc3911_enabled + self.server_name = hs.config.server.server_name def register(self, http_server: HttpServer) -> None: # /rooms/$roomid/state/$eventtype @@ -301,6 +375,18 @@ async def on_PUT( if txn_id: set_tag("txn_id", txn_id) + media_info_for_attachment: set[LocalMedia] = set() + if self.enable_restricted_media: + # This will raise if any of the attachment parameters or the requester is + # inappropriate + media_info_for_attachment = ( + await validate_attachment_request_and_retrieve_media_info( + requester, + request, + self.store.get_local_media, + ) + ) + content = parse_json_object_from_request(request) is_requester_admin = await self.auth.is_server_admin(requester) @@ -355,6 +441,7 @@ async def on_PUT( action=membership, content=content, origin_server_ts=origin_server_ts, + media_info_for_attachment=media_info_for_attachment, ) else: event_dict: JsonDict = { @@ -374,7 +461,10 @@ async def on_PUT( event, _, ) = await 
self.event_creation_handler.create_and_send_nonmember_event( - requester, event_dict, txn_id=txn_id + requester, + event_dict, + txn_id=txn_id, + media_info_for_attachment=media_info_for_attachment, ) event_id = event.event_id except ShadowBanError: @@ -395,6 +485,9 @@ def __init__(self, hs: "HomeServer"): self.delayed_events_handler = hs.get_delayed_events_handler() self.auth = hs.get_auth() self._max_event_delay_ms = hs.config.server.max_event_delay_ms + self.store = hs.get_datastores().main + self.enable_restricted_media = hs.config.experimental.msc3911_enabled + self.server_name = hs.config.server.server_name def register(self, http_server: HttpServer) -> None: # /rooms/$roomid/send/$event_type[/$txn_id] @@ -410,6 +503,19 @@ async def _do( txn_id: Optional[str], ) -> Tuple[int, JsonDict]: content = parse_json_object_from_request(request) + # Requirement is only to do this for PUT, but the POST also uses the same + # abstraction. It appears the PUT variant includes a txn_id, perhaps use that? 
+ media_info_for_attachment: set[LocalMedia] = set() + if self.enable_restricted_media: + # This will raise if any of the attachment parameters or the requester is + # inappropriate + media_info_for_attachment = ( + await validate_attachment_request_and_retrieve_media_info( + requester, + request, + self.store.get_local_media, + ) + ) origin_server_ts = None if requester.app_service: @@ -446,7 +552,10 @@ async def _do( event, _, ) = await self.event_creation_handler.create_and_send_nonmember_event( - requester, event_dict, txn_id=txn_id + requester, + event_dict, + txn_id=txn_id, + media_info_for_attachment=media_info_for_attachment, ) event_id = event.event_id except ShadowBanError: diff --git a/synapse/synapse_rust/events.pyi b/synapse/synapse_rust/events.pyi index 7d3422572d..9b342d2a00 100644 --- a/synapse/synapse_rust/events.pyi +++ b/synapse/synapse_rust/events.pyi @@ -39,6 +39,8 @@ class EventInternalMetadata: """The access token ID of the user who sent this event, if any.""" device_id: str """The device ID of the user who sent this event, if any.""" + media_references: Optional[List[str]] + """The media references that acts as a restriction to this event, if any.""" def get_dict(self) -> JsonDict: ... def is_outlier(self) -> bool: ... 
diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py index 48d33b8e17..62d6a4fb2a 100644 --- a/tests/rest/client/test_rooms.py +++ b/tests/rest/client/test_rooms.py @@ -23,12 +23,15 @@ """Tests REST events for /rooms paths.""" +import io import json +import time from http import HTTPStatus from typing import Any, Dict, Iterable, List, Literal, Optional, Tuple, Union from unittest.mock import AsyncMock, Mock, call, patch from urllib import parse as urlparse +from matrix_common.types.mxc_uri import MXCUri from parameterized import param, parameterized from twisted.test.proto_helpers import MemoryReactor @@ -52,6 +55,7 @@ directory, knock, login, + media, profile, register, room, @@ -65,6 +69,7 @@ from tests import unittest from tests.http.server._base import make_request_with_cancellation_test from tests.storage.test_stream import PaginationTestCase +from tests.test_utils import SMALL_PNG from tests.test_utils.event_injection import create_event from tests.unittest import override_config from tests.utils import default_config @@ -4499,3 +4504,607 @@ def test_sending_event_and_leaving_does_not_record_participation( self.store.get_room_participation(self.user2, self.room1) ) self.assertFalse(participant) + + +class RoomStateMediaAttachmentTestCase(unittest.HomeserverTestCase): + servlets = [ + admin.register_servlets, + login.register_servlets, + media.register_servlets, + room.register_servlets, + room.register_deprecated_servlets, + ] + + def prepare( + self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer + ) -> None: + self.store = homeserver.get_datastores().main + self.server_name = self.hs.config.server.server_name + self.media_repo = self.hs.get_media_repository() + + self.user = self.register_user("david", "password") + self.tok = self.login("david", "password") + + self.other_user = self.register_user("mongo", "password") + self.other_tok = self.login("mongo", "password") + + def default_config(self) -> JsonDict: + 
config = super().default_config() + config.setdefault("experimental_features", {}) + config["experimental_features"].update({"msc3911_enabled": True}) + return config + + def create_media_and_set_restricted_flag( + self, user_id: Optional[str] = None + ) -> MXCUri: + """ + Create media without using an endpoint, and set the restricted flag. This will + not add restrictions on its own, as that is the point of this test series + """ + # Allow for testing a different user doing the creation, so can test it errors + # when attaching + if user_id is None: + user_id = self.user + content = io.BytesIO(SMALL_PNG) + content_uri = self.get_success( + self.media_repo.create_or_update_content( + "image/png", + "test_png_upload", + content, + 67, + UserID.from_string(user_id), + restricted=True, + ) + ) + return content_uri + + def test_can_attach_media_to_state_event(self) -> None: + """Test basic functionality, that a media ID can be attached to a state event""" + room_id = self.helper.create_room_as(self.user, tok=self.tok) + mxc_uri = self.create_media_and_set_restricted_flag() + # without a 'state_key' the url does not need a trailing '/' + channel1 = self.make_request( + "PUT", + f"/rooms/{room_id}/state/m.room.avatar?org.matrix.msc3911.attach_media={str(mxc_uri)}", + { + "info": {"h": 1, "mimetype": "image/png", "size": 67, "w": 1}, + "url": str(mxc_uri), + }, + access_token=self.tok, + ) + self.assertEqual(channel1.code, HTTPStatus.OK, channel1.json_body) + assert "event_id" in channel1.json_body + event_id = channel1.json_body["event_id"] + + restrictions = self.get_success( + self.store.get_media_restrictions(mxc_uri.server_name, mxc_uri.media_id) + ) + assert restrictions is not None, str(restrictions) + assert restrictions.event_id == event_id + assert restrictions.profile_user_id is None + + def test_attaching_nonexistent_media_to_state_event_fails(self) -> None: + """Test that media that does not exist is not allowed to be attached to an event""" + room_id = 
self.helper.create_room_as(self.user, tok=self.tok) + nonexistent_mxc_uri = MXCUri.from_str("mxc://test/fakeMediaId") + + # without a 'state_key' the url does not need a trailing '/' + channel1 = self.make_request( + "PUT", + f"/rooms/{room_id}/state/m.room.avatar?org.matrix.msc3911.attach_media={str(nonexistent_mxc_uri)}", + { + "info": {"h": 1, "mimetype": "image/png", "size": 67, "w": 1}, + "url": str(nonexistent_mxc_uri), + }, + access_token=self.tok, + ) + self.assertEqual(channel1.code, HTTPStatus.BAD_REQUEST, channel1.json_body) + assert "errcode" in channel1.json_body + assert channel1.json_body["errcode"] == Codes.INVALID_PARAM + + restrictions = self.get_success( + self.store.get_media_restrictions( + nonexistent_mxc_uri.server_name, nonexistent_mxc_uri.media_id + ) + ) + assert restrictions is None, str(restrictions) + + def test_attaching_already_claimed_media_to_state_event_fails(self) -> None: + """Test that attaching media to state event fails if media is already attached""" + room_id = self.helper.create_room_as(self.user, tok=self.tok) + mxc_uri = self.create_media_and_set_restricted_flag() + # attach media to some other event before we place our state request + self.get_success( + self.store.set_media_restrictions( + mxc_uri.server_name, + mxc_uri.media_id, + {"restrictions": {"event_id": "$some_fake_event_id"}}, + ) + ) + + # without a 'state_key' the url does not need a trailing '/' + channel1 = self.make_request( + "PUT", + f"/rooms/{room_id}/state/m.room.avatar?org.matrix.msc3911.attach_media={str(mxc_uri)}", + { + "info": {"h": 1, "mimetype": "image/png", "size": 67, "w": 1}, + "url": str(mxc_uri), + }, + access_token=self.tok, + ) + self.assertEqual(channel1.code, HTTPStatus.BAD_REQUEST, channel1.json_body) + assert "errcode" in channel1.json_body + assert channel1.json_body["errcode"] == Codes.INVALID_PARAM + + def test_state_event_failing_does_not_attach_media(self) -> None: + """Test that state event that should not succeed does not 
attach the media""" + # Run these two multi-angle tests: + # * m.room.alias points at some other room + # * m.room.avatar is sent by user not in room + + # first test + room_id = self.helper.create_room_as(self.user, tok=self.tok) + mxc_uri = self.create_media_and_set_restricted_flag() + # without a 'state_key' the url does not need a trailing '/' + channel1 = self.make_request( + "PUT", + f"/rooms/!wrong_room:test/state/m.room.canonical_alias?org.matrix.msc3911.attach_media={str(mxc_uri)}", + {"alias": "#whatever:test"}, + access_token=self.tok, + ) + + self.assertEqual(channel1.code, HTTPStatus.FORBIDDEN, channel1.json_body) + assert "errcode" in channel1.json_body + assert channel1.json_body["errcode"] == Codes.FORBIDDEN + + restrictions = self.get_success( + self.store.get_media_restrictions(mxc_uri.server_name, mxc_uri.media_id) + ) + assert restrictions is None, str(restrictions) + + # second test + other_mxc_uri = self.create_media_and_set_restricted_flag(self.other_user) + channel1 = self.make_request( + "PUT", + f"/rooms/{room_id}/state/m.room.avatar?org.matrix.msc3911.attach_media={str(other_mxc_uri)}", + { + "info": {"h": 1, "mimetype": "image/png", "size": 67, "w": 1}, + "url": str(other_mxc_uri), + }, + access_token=self.other_tok, + ) + self.assertEqual(channel1.code, HTTPStatus.FORBIDDEN, channel1.json_body) + assert "errcode" in channel1.json_body + assert channel1.json_body["errcode"] == Codes.FORBIDDEN + + restrictions = self.get_success( + self.store.get_media_restrictions( + other_mxc_uri.server_name, other_mxc_uri.media_id + ) + ) + assert restrictions is None, str(restrictions) + + def test_state_event_deduplication_does_not_attach_media(self) -> None: + """Test that sending two identical state events does not cause an error with attached media""" + room_id = self.helper.create_room_as(self.user, tok=self.tok) + mxc_uri = self.create_media_and_set_restricted_flag() + # without a 'state_key' the url does not need a trailing '/' + channel1 = 
self.make_request( + "PUT", + f"/rooms/{room_id}/state/m.room.avatar?org.matrix.msc3911.attach_media={str(mxc_uri)}", + { + "info": {"h": 1, "mimetype": "image/png", "size": 67, "w": 1}, + "url": str(mxc_uri), + }, + access_token=self.tok, + ) + self.assertEqual(channel1.code, HTTPStatus.OK, channel1.json_body) + assert "event_id" in channel1.json_body + event_id = channel1.json_body["event_id"] + + restrictions = self.get_success( + self.store.get_media_restrictions(mxc_uri.server_name, mxc_uri.media_id) + ) + assert restrictions is not None, str(restrictions) + assert restrictions.event_id == event_id + assert restrictions.profile_user_id is None + + # Now do it again, exactly the same should de-duplicate the event and not error + channel1 = self.make_request( + "PUT", + f"/rooms/{room_id}/state/m.room.avatar?org.matrix.msc3911.attach_media={str(mxc_uri)}", + { + "info": {"h": 1, "mimetype": "image/png", "size": 67, "w": 1}, + "url": str(mxc_uri), + }, + access_token=self.tok, + ) + self.assertEqual(channel1.code, HTTPStatus.OK, channel1.json_body) + assert "event_id" in channel1.json_body + assert channel1.json_body["event_id"] == event_id + + restrictions = self.get_success( + self.store.get_media_restrictions(mxc_uri.server_name, mxc_uri.media_id) + ) + assert restrictions is not None, str(restrictions) + assert restrictions.event_id == event_id + assert restrictions.profile_user_id is None + + def test_attaching_media_without_mxc_scheme_does_not_fail(self) -> None: + room_id = self.helper.create_room_as(self.user, tok=self.tok) + mxc_uri = self.create_media_and_set_restricted_flag() + schemeless_mxc_uri = f"{mxc_uri.server_name}/{mxc_uri.media_id}" + # without a 'state_key' the url does not need a trailing '/' + channel1 = self.make_request( + "PUT", + f"/rooms/{room_id}/state/m.room.avatar?org.matrix.msc3911.attach_media={schemeless_mxc_uri}", + { + "info": {"h": 1, "mimetype": "image/png", "size": 67, "w": 1}, + "url": str(mxc_uri), + }, + 
access_token=self.tok, + ) + self.assertEqual(channel1.code, HTTPStatus.OK, channel1.json_body) + assert "event_id" in channel1.json_body + event_id = channel1.json_body["event_id"] + + restrictions = self.get_success( + self.store.get_media_restrictions(mxc_uri.server_name, mxc_uri.media_id) + ) + assert restrictions is not None, str(restrictions) + assert restrictions.event_id == event_id + assert restrictions.profile_user_id is None + + def test_can_attach_multiple_pieces_of_media_to_state_event(self) -> None: + room_id = self.helper.create_room_as(self.user, tok=self.tok) + first_mxc_uri = self.create_media_and_set_restricted_flag() + second_mxc_uri = self.create_media_and_set_restricted_flag() + # without a 'state_key' the url does not need a trailing '/' + channel1 = self.make_request( + "PUT", + f"/rooms/{room_id}/state/m.room.avatar?org.matrix.msc3911.attach_media={str(first_mxc_uri)}&org.matrix.msc3911.attach_media={str(second_mxc_uri)}", + { + "info": {"h": 1, "mimetype": "image/png", "size": 67, "w": 1}, + "url": str(first_mxc_uri), + }, + access_token=self.tok, + ) + self.assertEqual(channel1.code, HTTPStatus.OK, channel1.json_body) + assert "event_id" in channel1.json_body + event_id = channel1.json_body["event_id"] + + restrictions = self.get_success( + self.store.get_media_restrictions( + first_mxc_uri.server_name, first_mxc_uri.media_id + ) + ) + assert restrictions is not None, str(restrictions) + assert restrictions.event_id == event_id + assert restrictions.profile_user_id is None + + restrictions = self.get_success( + self.store.get_media_restrictions( + second_mxc_uri.server_name, second_mxc_uri.media_id + ) + ) + assert restrictions is not None, str(restrictions) + assert restrictions.event_id == event_id + assert restrictions.profile_user_id is None + + def test_media_can_not_be_attached_by_user_that_did_not_upload(self) -> None: + """Test that a user attaching media is the same one that uploaded it""" + room_id = 
self.helper.create_room_as(self.user, tok=self.tok) + mxc_uri = self.create_media_and_set_restricted_flag() + # without a 'state_key' the url does not need a trailing '/' + channel1 = self.make_request( + "PUT", + f"/rooms/{room_id}/state/m.room.avatar?org.matrix.msc3911.attach_media={str(mxc_uri)}", + { + "info": {"h": 1, "mimetype": "image/png", "size": 67, "w": 1}, + "url": str(mxc_uri), + }, + # wrong user + access_token=self.other_tok, + ) + self.assertEqual(channel1.code, HTTPStatus.BAD_REQUEST, channel1.json_body) + assert "errcode" in channel1.json_body + assert channel1.json_body["errcode"] == Codes.INVALID_PARAM + + def test_media_can_be_attached_to_member_state_event(self) -> None: + pass + + +class RoomSendEventMediaAttachmentTestCase(unittest.HomeserverTestCase): + servlets = [ + admin.register_servlets, + login.register_servlets, + media.register_servlets, + room.register_servlets, + room.register_deprecated_servlets, + ] + + def prepare( + self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer + ) -> None: + self.store = homeserver.get_datastores().main + self.server_name = self.hs.config.server.server_name + self.media_repo = self.hs.get_media_repository() + + self.user = self.register_user("david", "password") + self.tok = self.login("david", "password") + + self.other_user = self.register_user("mongo", "password") + self.other_tok = self.login("mongo", "password") + + def default_config(self) -> JsonDict: + config = super().default_config() + config.setdefault("experimental_features", {}) + config["experimental_features"].update({"msc3911_enabled": True}) + return config + + def create_media_and_set_restricted_flag( + self, user_id: Optional[str] = None + ) -> MXCUri: + """ + Create media without using an endpoint, and set the restricted flag. 
This will + not add restrictions on its own, as that is the point of this test series + """ + # Allow for testing a different user doing the creation, so can test it errors + # when attaching + if user_id is None: + user_id = self.user + content = io.BytesIO(SMALL_PNG) + content_uri = self.get_success( + self.media_repo.create_or_update_content( + "image/png", + "test_png_upload", + content, + 67, + UserID.from_string(user_id), + restricted=True, + ) + ) + return content_uri + + def test_can_attach_media_to_message_event(self) -> None: + room_id = self.helper.create_room_as(self.user, tok=self.tok) + mxc_uri = self.create_media_and_set_restricted_flag() + txn_id = "m%s" % (str(time.time())) + + channel1 = self.make_request( + "PUT", + f"/rooms/{room_id}/send/m.room.message/{txn_id}?org.matrix.msc3911.attach_media={str(mxc_uri)}", + content={"msgtype": "m.text", "body": "Hi, this is a message"}, + access_token=self.tok, + ) + self.assertEqual(channel1.code, HTTPStatus.OK, channel1.json_body) + assert "event_id" in channel1.json_body + event_id = channel1.json_body["event_id"] + + restrictions = self.get_success( + self.store.get_media_restrictions(mxc_uri.server_name, mxc_uri.media_id) + ) + assert restrictions is not None, str(restrictions) + assert restrictions.event_id == event_id + assert restrictions.profile_user_id is None + + def test_attaching_nonexistent_media_to_event_fails(self) -> None: + """Test that media that does not exist is not allowed to be attached to an event""" + room_id = self.helper.create_room_as(self.user, tok=self.tok) + nonexistent_mxc_uri = MXCUri.from_str("mxc://test/fakeMediaId") + txn_id = "m%s" % (str(time.time())) + + channel1 = self.make_request( + "PUT", + f"/rooms/{room_id}/send/m.room.message/{txn_id}?org.matrix.msc3911.attach_media={str(nonexistent_mxc_uri)}", + content={"msgtype": "m.text", "body": "Hi, this is a message"}, + access_token=self.tok, + ) + # self.assertEqual(channel1.code, HTTPStatus.BAD_REQUEST, 
channel1.json_body) + self.assertEqual(channel1.code, HTTPStatus.BAD_REQUEST, channel1.json_body) + assert "errcode" in channel1.json_body + assert channel1.json_body["errcode"] == Codes.INVALID_PARAM + + restrictions = self.get_success( + self.store.get_media_restrictions( + nonexistent_mxc_uri.server_name, nonexistent_mxc_uri.media_id + ) + ) + assert restrictions is None, str(restrictions) + + def test_attaching_already_claimed_media_to_event_fails(self) -> None: + """Test that attaching media to event fails if media is already attached""" + room_id = self.helper.create_room_as(self.user, tok=self.tok) + mxc_uri = self.create_media_and_set_restricted_flag() + txn_id = "m%s" % (str(time.time())) + + # attach media to some other event before we place our send event request + self.get_success( + self.store.set_media_restrictions( + mxc_uri.server_name, + mxc_uri.media_id, + {"restrictions": {"event_id": "$some_fake_event_id"}}, + ) + ) + + channel1 = self.make_request( + "PUT", + f"/rooms/{room_id}/send/m.room.message/{txn_id}?org.matrix.msc3911.attach_media={str(mxc_uri)}", + content={"msgtype": "m.text", "body": "Hi, this is a message"}, + access_token=self.tok, + ) + self.assertEqual(channel1.code, HTTPStatus.BAD_REQUEST, channel1.json_body) + assert "errcode" in channel1.json_body + assert channel1.json_body["errcode"] == Codes.INVALID_PARAM + + def test_event_failing_does_not_attach_media(self) -> None: + """Test that event that should not succeed does not attach the media""" + # Run these two multi-angle tests: + # * m.room.message points at some other room + # * m.room.message is sent by user not in room + + # first test + room_id = self.helper.create_room_as(self.user, tok=self.tok) + mxc_uri = self.create_media_and_set_restricted_flag() + txn_id = "m%s" % (str(time.time())) + + channel1 = self.make_request( + "PUT", + f"/rooms/!wrong_room:test/send/m.room.message/{txn_id}?org.matrix.msc3911.attach_media={str(mxc_uri)}", + content={"msgtype": "m.text", 
"body": "Hi, this is a message"}, + access_token=self.tok, + ) + # self.assertEqual(channel1.code, HTTPStatus.OK, channel1.json_body) + # assert "event_id" in channel1.json_body + # event_id = channel1.json_body["event_id"] + self.assertEqual(channel1.code, HTTPStatus.FORBIDDEN, channel1.json_body) + assert "errcode" in channel1.json_body + assert channel1.json_body["errcode"] == Codes.FORBIDDEN + + restrictions = self.get_success( + self.store.get_media_restrictions(mxc_uri.server_name, mxc_uri.media_id) + ) + assert restrictions is None, str(restrictions) + + # second test + other_mxc_uri = self.create_media_and_set_restricted_flag(self.other_user) + txn_id = "m%s" % (str(time.time())) + + channel1 = self.make_request( + "PUT", + f"/rooms/{room_id}/send/m.room.message/{txn_id}?org.matrix.msc3911.attach_media={str(other_mxc_uri)}", + content={"msgtype": "m.text", "body": "Hi, this is a message"}, + access_token=self.other_tok, + ) + self.assertEqual(channel1.code, HTTPStatus.FORBIDDEN, channel1.json_body) + assert "errcode" in channel1.json_body + assert channel1.json_body["errcode"] == Codes.FORBIDDEN + + restrictions = self.get_success( + self.store.get_media_restrictions( + other_mxc_uri.server_name, other_mxc_uri.media_id + ) + ) + assert restrictions is None, str(restrictions) + + def test_attaching_media_without_mxc_scheme_does_not_fail(self) -> None: + room_id = self.helper.create_room_as(self.user, tok=self.tok) + mxc_uri = self.create_media_and_set_restricted_flag() + schemeless_mxc_uri = f"{mxc_uri.server_name}/{mxc_uri.media_id}" + txn_id = "m%s" % (str(time.time())) + + channel1 = self.make_request( + "PUT", + f"/rooms/{room_id}/send/m.room.message/{txn_id}?org.matrix.msc3911.attach_media={schemeless_mxc_uri}", + content={"msgtype": "m.text", "body": "Hi, this is a message"}, + access_token=self.tok, + ) + self.assertEqual(channel1.code, HTTPStatus.OK, channel1.json_body) + assert "event_id" in channel1.json_body + event_id = 
channel1.json_body["event_id"] + + restrictions = self.get_success( + self.store.get_media_restrictions(mxc_uri.server_name, mxc_uri.media_id) + ) + assert restrictions is not None, str(restrictions) + assert restrictions.event_id == event_id + assert restrictions.profile_user_id is None + + def test_can_attach_multiple_pieces_of_media_to_event(self) -> None: + room_id = self.helper.create_room_as(self.user, tok=self.tok) + first_mxc_uri = self.create_media_and_set_restricted_flag() + second_mxc_uri = self.create_media_and_set_restricted_flag() + txn_id = "m%s" % (str(time.time())) + + channel1 = self.make_request( + "PUT", + f"/rooms/{room_id}/send/m.room.message/{txn_id}?org.matrix.msc3911.attach_media={str(first_mxc_uri)}&org.matrix.msc3911.attach_media={str(second_mxc_uri)}", + content={"msgtype": "m.text", "body": "Hi, this is a message"}, + access_token=self.tok, + ) + self.assertEqual(channel1.code, HTTPStatus.OK, channel1.json_body) + assert "event_id" in channel1.json_body + event_id = channel1.json_body["event_id"] + + restrictions = self.get_success( + self.store.get_media_restrictions( + first_mxc_uri.server_name, first_mxc_uri.media_id + ) + ) + assert restrictions is not None, str(restrictions) + assert restrictions.event_id == event_id + assert restrictions.profile_user_id is None + + restrictions = self.get_success( + self.store.get_media_restrictions( + second_mxc_uri.server_name, second_mxc_uri.media_id + ) + ) + assert restrictions is not None, str(restrictions) + assert restrictions.event_id == event_id + assert restrictions.profile_user_id is None + + def test_media_can_not_be_attached_by_user_that_did_not_upload(self) -> None: + """Test that a user attaching media is the same one that uploaded it""" + room_id = self.helper.create_room_as(self.user, tok=self.tok) + mxc_uri = self.create_media_and_set_restricted_flag() + txn_id = "m%s" % (str(time.time())) + + channel1 = self.make_request( + "PUT", + 
f"/rooms/{room_id}/send/m.room.message/{txn_id}?org.matrix.msc3911.attach_media={str(mxc_uri)}", + content={"msgtype": "m.text", "body": "Hi, this is a message"}, + # wrong user + access_token=self.other_tok, + ) + self.assertEqual(channel1.code, HTTPStatus.BAD_REQUEST, channel1.json_body) + assert "errcode" in channel1.json_body + assert channel1.json_body["errcode"] == Codes.INVALID_PARAM + + def test_idempotency_of_attaching_media_to_message_event(self) -> None: + """Test that a request with exactly the same parameters does not fail""" + # Unlike state events that have a de-duplication mechanism, sending normal + # events has a transaction component. Make sure that acts as expected + room_id = self.helper.create_room_as(self.user, tok=self.tok) + mxc_uri = self.create_media_and_set_restricted_flag() + txn_id = "m%s" % (str(time.time())) + + # First request + channel1 = self.make_request( + "PUT", + f"/rooms/{room_id}/send/m.room.message/{txn_id}?org.matrix.msc3911.attach_media={str(mxc_uri)}", + content={"msgtype": "m.text", "body": "Hi, this is a message"}, + access_token=self.tok, + ) + self.assertEqual(channel1.code, HTTPStatus.OK, channel1.json_body) + assert "event_id" in channel1.json_body + event_id = channel1.json_body["event_id"] + + restrictions = self.get_success( + self.store.get_media_restrictions(mxc_uri.server_name, mxc_uri.media_id) + ) + assert restrictions is not None, str(restrictions) + assert restrictions.event_id == event_id + assert restrictions.profile_user_id is None + + # Second request, identical to the first including using the same time for the txn_id + channel1 = self.make_request( + "PUT", + f"/rooms/{room_id}/send/m.room.message/{txn_id}?org.matrix.msc3911.attach_media={str(mxc_uri)}", + content={"msgtype": "m.text", "body": "Hi, this is a message"}, + access_token=self.tok, + ) + self.assertEqual(channel1.code, HTTPStatus.OK, channel1.json_body) + assert "event_id" in channel1.json_body + # if the event_id returned here matches 
the one from above, we know it was idempotent + assert event_id == channel1.json_body["event_id"] + + restrictions = self.get_success( + self.store.get_media_restrictions(mxc_uri.server_name, mxc_uri.media_id) + ) + assert restrictions is not None, str(restrictions) + assert restrictions.event_id == event_id + assert restrictions.profile_user_id is None + + +# Sort if need to do annotations and reactions and other m.relates_to stuff here From b1ce63d07f35c2a30833d6adf3ca8c1620d3db77 Mon Sep 17 00:00:00 2001 From: Jason Little Date: Wed, 3 Sep 2025 08:33:34 -0500 Subject: [PATCH 06/35] fix: Ensure that failure to persist an Event does not incorrectly set a media restriction --- synapse/handlers/message.py | 12 --- synapse/storage/databases/main/events.py | 16 ++++ .../databases/main/media_repository.py | 86 +++++++++++++++++-- tests/rest/client/test_rooms.py | 8 +- tests/storage/test_media.py | 28 +++--- 5 files changed, 112 insertions(+), 38 deletions(-) diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index da1a5724de..e7a8e1efbc 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -2133,18 +2133,6 @@ async def persist_and_notify_client_events( events_and_pos = [] for event in persisted_events: - # Access the 'media_references' object from the event internal metadata. - # This will be None if it was not attached during creation of the event. - maybe_media_restrictions_to_set = event.internal_metadata.media_references - - if maybe_media_restrictions_to_set: - for mxc_str in maybe_media_restrictions_to_set: - mxc = MXCUri.from_str(mxc_str) - await self.store.set_media_restrictions( - mxc.server_name, - mxc.media_id, - {"restrictions": {"event_id": event.event_id}}, - ) if self._ephemeral_events_enabled: # If there's an expiry timestamp on the event, schedule its expiry. 
self._message_handler.maybe_schedule_expiry(event) diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index 741146417f..14b23daf81 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -40,6 +40,7 @@ ) import attr +from matrix_common.types.mxc_uri import MXCUri from prometheus_client import Counter import synapse.metrics @@ -2701,6 +2702,21 @@ def _update_metadata_tables_txn( if type(expiry_ts) is int and not event.is_state(): # noqa: E721 self._insert_event_expiry_txn(txn, event.event_id, expiry_ts) + # Access the 'media_references' object from the event internal metadata. + # This will be None if media was not attached during creation of the event. + # Part of MSC3911: Linking media to events + maybe_media_restrictions_to_set = ( + event.internal_metadata.media_references or [] + ) + for mxc_str in maybe_media_restrictions_to_set: + mxc = MXCUri.from_str(mxc_str) + self.store.set_media_restricted_to_event_id_txn( + txn, + server_name=mxc.server_name, + media_id=mxc.media_id, + event_id=event.event_id, + ) + # Insert into the room_memberships table. 
self._store_room_members_txn( txn, diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py index 8b4be62b3c..e175253d8f 100644 --- a/synapse/storage/databases/main/media_repository.py +++ b/synapse/storage/databases/main/media_repository.py @@ -19,7 +19,6 @@ # [This file includes modifications made by New Vector Limited] # # -import json import logging from enum import Enum from http import HTTPStatus @@ -47,6 +46,7 @@ LoggingTransaction, ) from synapse.types import JsonDict, UserID +from synapse.util import json_encoder if TYPE_CHECKING: from synapse.server import HomeServer @@ -1163,30 +1163,102 @@ async def get_media_restrictions( return None - async def set_media_restrictions( + async def set_media_restricted_to_event_id( self, server_name: str, media_id: str, - media_restrictions_json: JsonDict, + event_id: str, ) -> None: """ - Add the media restrictions to the database + Add the media restrictions to a given Event ID to the database Args: server_name: media_id: - media_restrictions_json: The media restrictions as dict + event_id: The Event ID to restrict the media to Raises: SynapseError if the media already has restrictions on it """ + await self.db_pool.runInteraction( + "set_media_restricted_to_event_id", + self.set_media_restricted_to_event_id_txn, + server_name=server_name, + media_id=media_id, + event_id=event_id, + ) + + def set_media_restricted_to_event_id_txn( + self, + txn: LoggingTransaction, + *, + server_name: str, + media_id: str, + event_id: str, + ) -> None: + json_object = {"restrictions": {"event_id": event_id}} + try: + self.db_pool.simple_insert_txn( + txn, + "media_attachments", + { + "server_name": server_name, + "media_id": media_id, + "restrictions_json": json_encoder.encode(json_object), + }, + ) + except self.db_pool.engine.module.IntegrityError: + # For sqlite, a unique constraint violation is an integrity error. 
For + # psycopg2, a UniqueViolation is a subclass of IntegrityError, so this + # covers both. + raise SynapseError( + HTTPStatus.BAD_REQUEST, + f"This media, '{media_id}' already has restrictions set.", + errcode=Codes.INVALID_PARAM, + ) + + async def set_media_restricted_to_user_profile( + self, + server_name: str, + media_id: str, + profile_user_id: str, + ) -> None: + """ + Add the media restrictions to a given profile for a User ID to the database + + Args: + server_name: + media_id: + profile_user_id: The User ID's profile to restrict the media to + + Raises: + SynapseError if the media already has restrictions on it + """ + await self.db_pool.runInteraction( + "set_media_restricted_to_user_profile", + self.set_media_restricted_to_user_profile_txn, + server_name=server_name, + media_id=media_id, + profile_user_id=profile_user_id, + ) + + def set_media_restricted_to_user_profile_txn( + self, + txn: LoggingTransaction, + *, + server_name: str, + media_id: str, + profile_user_id: str, + ) -> None: + json_object = {"restrictions": {"profile_user_id": profile_user_id}} try: - await self.db_pool.simple_insert( + self.db_pool.simple_insert_txn( + txn, "media_attachments", { "server_name": server_name, "media_id": media_id, - "restrictions_json": json.dumps(media_restrictions_json), + "restrictions_json": json_encoder.encode(json_object), }, ) except self.db_pool.engine.module.IntegrityError: diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py index 62d6a4fb2a..d7255dee25 100644 --- a/tests/rest/client/test_rooms.py +++ b/tests/rest/client/test_rooms.py @@ -4615,10 +4615,10 @@ def test_attaching_already_claimed_media_to_state_event_fails(self) -> None: mxc_uri = self.create_media_and_set_restricted_flag() # attach media to some other event before we place our state request self.get_success( - self.store.set_media_restrictions( + self.store.set_media_restricted_to_event_id( mxc_uri.server_name, mxc_uri.media_id, - {"restrictions": 
{"event_id": "$some_fake_event_id"}}, + "$some_fake_event_id", ) ) @@ -4920,10 +4920,10 @@ def test_attaching_already_claimed_media_to_event_fails(self) -> None: # attach media to some other event before we place our send event request self.get_success( - self.store.set_media_restrictions( + self.store.set_media_restricted_to_event_id( mxc_uri.server_name, mxc_uri.media_id, - {"restrictions": {"event_id": "$some_fake_event_id"}}, + "$some_fake_event_id", ) ) diff --git a/tests/storage/test_media.py b/tests/storage/test_media.py index a3eaa3ce93..d6b9219e3a 100644 --- a/tests/storage/test_media.py +++ b/tests/storage/test_media.py @@ -30,11 +30,10 @@ def prepare( def test_store_and_retrieve_media_restrictions_by_event_id(self) -> None: event_id = "$random_event_id" - media_restrictions = {"restrictions": {"event_id": event_id}} media_id = random_string(24) self.get_success_or_raise( - self.store.set_media_restrictions( - self.server_name, media_id, media_restrictions + self.store.set_media_restricted_to_event_id( + self.server_name, media_id, event_id ) ) @@ -47,11 +46,10 @@ def test_store_and_retrieve_media_restrictions_by_event_id(self) -> None: def test_store_and_retrieve_media_restrictions_by_profile_user_id(self) -> None: user_id = UserID.from_string("@frank:test") - media_restrictions = {"restrictions": {"profile_user_id": user_id.to_string()}} media_id = random_string(24) self.get_success_or_raise( - self.store.set_media_restrictions( - self.server_name, media_id, media_restrictions + self.store.set_media_restricted_to_user_profile( + self.server_name, media_id, user_id.to_string() ) ) @@ -109,12 +107,11 @@ def test_setting_media_restriction_twice_errors( content_uri: str = upload_result["content_uri"] # We can split the content_uri on the last "/" and the rest is the media_id media_id = content_uri.rsplit("/", maxsplit=1)[1] - event_id = "$something_hashy_doesnt_matter" - media_restrictions = {"restrictions": {"event_id": event_id}} + self.get_success( - 
self.store.set_media_restrictions( - self.server_name, media_id, media_restrictions + self.store.set_media_restricted_to_event_id( + self.server_name, media_id, event_id ) ) @@ -127,8 +124,8 @@ def test_setting_media_restriction_twice_errors( assert existing_media_restrictions is not None self.get_failure( - self.store.set_media_restrictions( - self.server_name, media_id, media_restrictions + self.store.set_media_restricted_to_event_id( + self.server_name, media_id, event_id ), SynapseError, ) @@ -176,7 +173,8 @@ def test_flow(self) -> None: # Create media by using create_or_update_content() helper. This will likely be # on the new `/create` and `/upload` endpoints for msc3911. - # set actual restrictions using storage method `set_media_restrictions()` + # set actual restrictions using storage methods + # `set_media_restricted_to_event_id()` or `set_media_restricted_to_user_profile()` # use `get_local_media()` to retrieve the data @@ -193,10 +191,10 @@ def test_flow(self) -> None: event_id = "$event_id_hash_goes_here" self.get_success( - self.store.set_media_restrictions( + self.store.set_media_restricted_to_event_id( self.server_name, media_id, - {"restrictions": {"event_id": event_id}}, + event_id, ) ) From ed40073ba244db4077a5e03f477c4a9e393915ac Mon Sep 17 00:00:00 2001 From: Soyoung Kim Date: Wed, 27 Aug 2025 15:06:47 +0200 Subject: [PATCH 07/35] feat: ap4 support attaching media for profile updates --- synapse/handlers/profile.py | 10 ++ synapse/rest/client/profile.py | 56 +++++++- tests/rest/client/test_profile.py | 215 +++++++++++++++++++++++++++++- 3 files changed, 277 insertions(+), 4 deletions(-) diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py index 76f910bf2d..f22ff9195e 100644 --- a/synapse/handlers/profile.py +++ b/synapse/handlers/profile.py @@ -77,6 +77,8 @@ def __init__(self, hs: "HomeServer"): self._third_party_rules = hs.get_module_api_callbacks().third_party_event_rules + self.enable_restricted_media = 
hs.config.experimental.msc3911_enabled + async def get_profile(self, user_id: str, ignore_backoff: bool = True) -> JsonDict: """ Get a user's profile as a JSON dictionary. @@ -332,6 +334,14 @@ async def set_avatar_url( await self.store.set_profile_avatar_url(target_user, avatar_url_to_set) + # msc3911: Update the media restrictions to include the profile user ID + if self.enable_restricted_media and avatar_url_to_set: + await self.hs.get_datastores().main.set_media_restricted_to_user_profile( + self.hs.config.server.server_name, + avatar_url_to_set, + str(target_user), + ) + profile = await self.store.get_profileinfo(target_user) await self.user_directory_handler.handle_local_profile_change( target_user.to_string(), profile diff --git a/synapse/rest/client/profile.py b/synapse/rest/client/profile.py index 243245f739..d434947c45 100644 --- a/synapse/rest/client/profile.py +++ b/synapse/rest/client/profile.py @@ -25,6 +25,8 @@ from http import HTTPStatus from typing import TYPE_CHECKING, Tuple +from matrix_common.types.mxc_uri import MXCUri + from synapse.api.constants import ProfileFields from synapse.api.errors import Codes, SynapseError from synapse.handlers.profile import MAX_CUSTOM_FIELD_LEN @@ -36,7 +38,8 @@ ) from synapse.http.site import SynapseRequest from synapse.rest.client._base import client_patterns -from synapse.types import JsonDict, JsonValue, UserID +from synapse.storage.databases.main.media_repository import LocalMedia +from synapse.types import JsonDict, JsonValue, Requester, UserID from synapse.util.stringutils import is_namedspaced_grammar if TYPE_CHECKING: @@ -109,6 +112,47 @@ def __init__(self, hs: "HomeServer"): self.hs = hs self.profile_handler = hs.get_profile_handler() self.auth = hs.get_auth() + self.enable_restricted_media = hs.config.experimental.msc3911_enabled + self.media_repository = hs.get_media_repository() + + async def _validate_avatar_url_and_retrieve_media_info( + self, avatar_url: str, requester: Requester + ) -> LocalMedia: 
+ """ + Validate avatar_url arg and parse the mxc_uri. Then retrieve the media information. + + Args: + avatar_url: The raw avatar_url arg of request + requester: The user making the request + + Returns: + Return the media info, or None if appropriate + + Raises: + SynapseError: If any of the media is inappropriate or if the requester was not + allowed to attach the media + """ + mxc_uri = MXCUri.from_str(avatar_url) + media_info = await self.media_repository.store.get_local_media(mxc_uri.media_id) + if media_info is None: + raise SynapseError( + HTTPStatus.BAD_REQUEST, + f"The media attachment request is invalid as the media '{mxc_uri.media_id}' does not exist", + Codes.INVALID_PARAM, + ) + if not media_info.restricted: + raise SynapseError( + HTTPStatus.BAD_REQUEST, + f"The media attachment request is invalid as the media '{mxc_uri.media_id}' is not restricted", + Codes.INVALID_PARAM, + ) + if media_info.user_id != requester.user.to_string(): + raise SynapseError( + HTTPStatus.BAD_REQUEST, + f"The media attachment request is invalid as the media '{mxc_uri.media_id}' does not exist", + Codes.INVALID_PARAM, + ) + return media_info async def on_GET( self, request: SynapseRequest, user_id: str, field_name: str @@ -203,8 +247,16 @@ async def on_PUT( user, requester, new_value, is_admin, propagate=propagate ) elif field_name == ProfileFields.AVATAR_URL: + media = new_value + if self.enable_restricted_media and new_value: + validated_media = ( + await self._validate_avatar_url_and_retrieve_media_info( + new_value, requester + ) + ) + media = validated_media.media_id await self.profile_handler.set_avatar_url( - user, requester, new_value, is_admin, propagate=propagate + user, requester, media, is_admin, propagate=propagate ) else: await self.profile_handler.set_profile_field( diff --git a/tests/rest/client/test_profile.py b/tests/rest/client/test_profile.py index 49776d8e8c..d5b263802f 100644 --- a/tests/rest/client/test_profile.py +++ b/tests/rest/client/test_profile.py 
@@ -21,23 +21,27 @@ """Tests REST events for /profile paths.""" +import io import urllib.parse from http import HTTPStatus from typing import Any, Dict, Optional from canonicaljson import encode_canonical_json +from matrix_common.types.mxc_uri import MXCUri from twisted.test.proto_helpers import MemoryReactor +from twisted.web.resource import Resource from synapse.api.errors import Codes from synapse.rest import admin -from synapse.rest.client import login, profile, room +from synapse.rest.client import login, media, profile, room from synapse.server import HomeServer from synapse.storage.databases.main.profile import MAX_PROFILE_SIZE -from synapse.types import UserID +from synapse.types import JsonDict, UserID from synapse.util import Clock from tests import unittest +from tests.test_utils import SMALL_PNG from tests.utils import USE_POSTGRES_FOR_TESTS @@ -910,3 +914,210 @@ def test_can_lookup_own_profile(self) -> None: access_token=self.requester_tok, ) self.assertEqual(channel.code, 200, channel.result) + + +class ProfileMediaAttachmentTestCase(unittest.HomeserverTestCase): + servlets = [ + admin.register_servlets, + login.register_servlets, + media.register_servlets, + profile.register_servlets, + ] + + def prepare( + self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer + ) -> None: + self.store = homeserver.get_datastores().main + self.server_name = self.hs.config.server.server_name + self.media_repo = self.hs.get_media_repository() + + self.user = self.register_user("user", "password") + self.tok = self.login("user", "password") + + self.other_user = self.register_user("other_user", "password") + self.other_tok = self.login("other_user", "password") + + def default_config(self) -> JsonDict: + config = super().default_config() + config.setdefault("experimental_features", {}) + config["experimental_features"].update({"msc3911_enabled": True}) + return config + + def create_resource_dict(self) -> dict[str, Resource]: + resources = 
super().create_resource_dict() + resources["/_matrix/media"] = self.hs.get_media_repository_resource() + return resources + + def create_media_and_set_restricted_flag(self, user_id: str) -> MXCUri: + """ + Create media without using an endpoint, and set the restricted flag. + """ + content = io.BytesIO(SMALL_PNG) + content_uri = self.get_success( + self.media_repo.create_or_update_content( + "image/png", + "test_png_upload", + content, + 67, + UserID.from_string(user_id), + restricted=True, + ) + ) + return content_uri + + def test_can_attach_media_to_profile_update(self) -> None: + """ + Test basic functionality, that a media ID can be attached to a user profile id. + """ + mxc_uri = self.create_media_and_set_restricted_flag(self.user) + # Update user profile with attach_media. + channel = self.make_request( + "PUT", + f"/_matrix/client/v3/profile/{self.user}/avatar_url", + access_token=self.tok, + content={"avatar_url": str(mxc_uri)}, + ) + assert channel.code == HTTPStatus.OK + assert channel.json_body == {} + + # Check if the media's restrictions field is updated with the profile_user_id. + restrictions = self.get_success( + self.store.get_media_restrictions(mxc_uri.server_name, mxc_uri.media_id) + ) + assert restrictions is not None, str(restrictions) + assert restrictions.event_id is None + assert restrictions.profile_user_id == UserID.from_string(self.user) + + def test_attaching_nonexistent_media_to_profile_fails(self) -> None: + """ + Test that media that does not exist is not allowed to be attached to a user profile. + """ + # Generate non-existing media. 
+ nonexistent_mxc_uri = MXCUri.from_str("mxc://test/fakeMediaId") + channel = self.make_request( + "PUT", + f"/_matrix/client/v3/profile/{self.user}/avatar_url", + access_token=self.tok, + content={"avatar_url": str(nonexistent_mxc_uri)}, + ) + + assert channel.code == HTTPStatus.BAD_REQUEST, channel.json_body + assert channel.json_body["errcode"] == Codes.INVALID_PARAM + assert "does not exist" in channel.json_body["error"] + + def test_attaching_unrestricted_media_to_profile_fails(self) -> None: + """ + Test that attaching unrestricted media to user profile fails. + """ + # Create unrestricted media. + channel = self.make_request( + "POST", + "/_matrix/media/v3/upload?filename=test_png_upload", + SMALL_PNG, + access_token=self.tok, + content_type=b"image/png", + custom_headers=[("Content-Length", str(67))], + ) + assert channel.code == 200, channel.result + content_uri = MXCUri.from_str(channel.json_body["content_uri"]) + + # Check media is unrestricted. + media_info = self.get_success(self.store.get_local_media(content_uri.media_id)) + assert media_info is not None + assert not media_info.restricted + + # Try to update user profile with unrestricted media. + channel = self.make_request( + "PUT", + f"/_matrix/client/v3/profile/{self.user}/avatar_url", + access_token=self.tok, + content={"avatar_url": str(content_uri)}, + ) + assert channel.code == HTTPStatus.BAD_REQUEST, channel.json_body + assert channel.json_body["errcode"] == Codes.INVALID_PARAM + assert "is not restricted" in channel.json_body["error"] + + def test_attaching_already_attached_media_to_profile_fails(self) -> None: + """ + Test that attaching already attached media to user profile fails. + """ + mxc_uri = self.create_media_and_set_restricted_flag(self.user) + # Attach the media to the user profile. 
+ channel = self.make_request( + "PUT", + f"/_matrix/client/v3/profile/{self.user}/avatar_url", + access_token=self.tok, + content={"avatar_url": str(mxc_uri)}, + ) + assert channel.code == HTTPStatus.OK + assert channel.json_body == {} + + # Try attaching the same media again. + channel = self.make_request( + "PUT", + f"/_matrix/client/v3/profile/{self.user}/avatar_url", + access_token=self.tok, + content={"avatar_url": str(mxc_uri)}, + ) + assert channel.code == HTTPStatus.BAD_REQUEST, channel.json_body + assert channel.json_body["errcode"] == Codes.INVALID_PARAM + assert "already has restrictions set" in channel.json_body["error"] + + def test_attaching_not_owned_media_to_profile_fails(self) -> None: + """ + Test that attaching media not owned by the user to profile fails. + """ + # Media is created with other_user. + mxc_uri = self.create_media_and_set_restricted_flag(self.other_user) + # Try to attach the media from other_user to user. + channel = self.make_request( + "PUT", + f"/_matrix/client/v3/profile/{self.user}/avatar_url", + access_token=self.tok, + content={"avatar_url": str(mxc_uri)}, + ) + assert channel.code == HTTPStatus.BAD_REQUEST, channel.json_body + assert channel.json_body["errcode"] == Codes.INVALID_PARAM + assert "does not exist" in channel.json_body["error"] + + def test_remove_media_from_profile(self) -> None: + """ + Test that removing media from user profile works. + """ + mxc_uri = self.create_media_and_set_restricted_flag(self.user) + # Attach the media to the user profile. + channel = self.make_request( + "PUT", + f"/_matrix/client/v3/profile/{self.user}/avatar_url", + access_token=self.tok, + content={"avatar_url": str(mxc_uri)}, + ) + assert channel.code == HTTPStatus.OK + assert channel.json_body == {} + + # Check media is set as user avatar. 
+ user_avatar = self.get_success( + self.store.get_profile_avatar_url( + UserID.from_string(self.user), + ) + ) + assert user_avatar is not None + assert user_avatar in str(mxc_uri) + + # Remove the media from the user profile. + channel = self.make_request( + "PUT", + f"/_matrix/client/v3/profile/{self.user}/avatar_url", + access_token=self.tok, + content={"avatar_url": ""}, + ) + assert channel.code == HTTPStatus.OK + assert channel.json_body == {} + + # Check media is no longer attached. + user_avatar = self.get_success( + self.store.get_profile_avatar_url( + UserID.from_string(self.user), + ) + ) + assert user_avatar is None From 0dfb823f29ed821b9caab07c79f0fc199a14c080 Mon Sep 17 00:00:00 2001 From: Soyoung Kim Date: Wed, 3 Sep 2025 18:20:37 +0200 Subject: [PATCH 08/35] chore: make validate_media_url_and_retrieve_media_info reusable --- synapse/media/media_repository.py | 36 ++++++++++++++++++++ synapse/rest/client/profile.py | 55 +++---------------------------- tests/rest/client/test_profile.py | 2 +- 3 files changed, 42 insertions(+), 51 deletions(-) diff --git a/synapse/media/media_repository.py b/synapse/media/media_repository.py index c2e04ea767..d7c63e20ef 100644 --- a/synapse/media/media_repository.py +++ b/synapse/media/media_repository.py @@ -23,6 +23,7 @@ import logging import os import shutil +from http import HTTPStatus from io import BytesIO from typing import IO, TYPE_CHECKING, Dict, List, Optional, Set, Tuple @@ -533,6 +534,41 @@ async def get_local_media( request, responder, media_type, media_length, upload_name ) + async def validate_media_url_and_retrieve_media_info( + self, media_id: str, requester: Requester + ) -> LocalMedia: + """ + Validate media_id arg and parse the mxc_uri. Then retrieve the media information. 
+ + Args: + media_id: The raw media_id arg of request + requester: The user making the request + + Returns: + Return the media info, or None if appropriate + + Raises: + SynapseError: If any of the media is inappropriate or if the requester was not + allowed to attach the media + """ + if not media_id.startswith("mxc://"): + media_id = f"mxc://{media_id}" + mxc_uri = MXCUri.from_str(media_id) + media_info = await self.store.get_local_media(mxc_uri.media_id) + if media_info is None or media_info.user_id != requester.user.to_string(): + raise SynapseError( + HTTPStatus.BAD_REQUEST, + f"The media attachment request is invalid as the media '{mxc_uri.media_id}' does not exist", + Codes.INVALID_PARAM, + ) + if not media_info.restricted: + raise SynapseError( + HTTPStatus.BAD_REQUEST, + f"The media attachment request is invalid as the media '{mxc_uri.media_id}' is not restricted", + Codes.INVALID_PARAM, + ) + return media_info + async def get_remote_media( self, request: SynapseRequest, diff --git a/synapse/rest/client/profile.py b/synapse/rest/client/profile.py index d434947c45..cee8995724 100644 --- a/synapse/rest/client/profile.py +++ b/synapse/rest/client/profile.py @@ -25,8 +25,6 @@ from http import HTTPStatus from typing import TYPE_CHECKING, Tuple -from matrix_common.types.mxc_uri import MXCUri - from synapse.api.constants import ProfileFields from synapse.api.errors import Codes, SynapseError from synapse.handlers.profile import MAX_CUSTOM_FIELD_LEN @@ -38,8 +36,7 @@ ) from synapse.http.site import SynapseRequest from synapse.rest.client._base import client_patterns -from synapse.storage.databases.main.media_repository import LocalMedia -from synapse.types import JsonDict, JsonValue, Requester, UserID +from synapse.types import JsonDict, JsonValue, UserID from synapse.util.stringutils import is_namedspaced_grammar if TYPE_CHECKING: @@ -115,45 +112,6 @@ def __init__(self, hs: "HomeServer"): self.enable_restricted_media = hs.config.experimental.msc3911_enabled 
self.media_repository = hs.get_media_repository() - async def _validate_avatar_url_and_retrieve_media_info( - self, avatar_url: str, requester: Requester - ) -> LocalMedia: - """ - Validate avatar_url arg and parse the mxc_uri. Then retrieve the media information. - - Args: - avatar_url: The raw avatar_url arg of request - requester: The user making the request - - Returns: - Return the media info, or None if appropriate - - Raises: - SynapseError: If any of the media is inappropriate or if the requester was not - allowed to attach the media - """ - mxc_uri = MXCUri.from_str(avatar_url) - media_info = await self.media_repository.store.get_local_media(mxc_uri.media_id) - if media_info is None: - raise SynapseError( - HTTPStatus.BAD_REQUEST, - f"The media attachment request is invalid as the media '{mxc_uri.media_id}' does not exist", - Codes.INVALID_PARAM, - ) - if not media_info.restricted: - raise SynapseError( - HTTPStatus.BAD_REQUEST, - f"The media attachment request is invalid as the media '{mxc_uri.media_id}' is not restricted", - Codes.INVALID_PARAM, - ) - if media_info.user_id != requester.user.to_string(): - raise SynapseError( - HTTPStatus.BAD_REQUEST, - f"The media attachment request is invalid as the media '{mxc_uri.media_id}' does not exist", - Codes.INVALID_PARAM, - ) - return media_info - async def on_GET( self, request: SynapseRequest, user_id: str, field_name: str ) -> Tuple[int, JsonDict]: @@ -247,16 +205,13 @@ async def on_PUT( user, requester, new_value, is_admin, propagate=propagate ) elif field_name == ProfileFields.AVATAR_URL: - media = new_value if self.enable_restricted_media and new_value: - validated_media = ( - await self._validate_avatar_url_and_retrieve_media_info( - new_value, requester - ) + validated_media = await self.media_repository.validate_media_url_and_retrieve_media_info( + new_value, requester ) - media = validated_media.media_id + new_value = validated_media.media_id await self.profile_handler.set_avatar_url( - user, 
requester, media, is_admin, propagate=propagate + user, requester, new_value, is_admin, propagate=propagate ) else: await self.profile_handler.set_profile_field( diff --git a/tests/rest/client/test_profile.py b/tests/rest/client/test_profile.py index d5b263802f..7e703775f5 100644 --- a/tests/rest/client/test_profile.py +++ b/tests/rest/client/test_profile.py @@ -1102,7 +1102,7 @@ def test_remove_media_from_profile(self) -> None: ) ) assert user_avatar is not None - assert user_avatar in str(mxc_uri) + assert "mxc://test/" + user_avatar == str(mxc_uri) # Remove the media from the user profile. channel = self.make_request( From 7c596a19c49049c709d9097870cb2821384f4170 Mon Sep 17 00:00:00 2001 From: Soyoung Kim Date: Wed, 3 Sep 2025 19:05:42 +0200 Subject: [PATCH 09/35] chore: make update profile idempotent --- synapse/media/media_repository.py | 36 -------------------- synapse/rest/client/profile.py | 56 ++++++++++++++++++++++++++++--- tests/rest/client/test_profile.py | 26 +++++++++++++- 3 files changed, 77 insertions(+), 41 deletions(-) diff --git a/synapse/media/media_repository.py b/synapse/media/media_repository.py index d7c63e20ef..c2e04ea767 100644 --- a/synapse/media/media_repository.py +++ b/synapse/media/media_repository.py @@ -23,7 +23,6 @@ import logging import os import shutil -from http import HTTPStatus from io import BytesIO from typing import IO, TYPE_CHECKING, Dict, List, Optional, Set, Tuple @@ -534,41 +533,6 @@ async def get_local_media( request, responder, media_type, media_length, upload_name ) - async def validate_media_url_and_retrieve_media_info( - self, media_id: str, requester: Requester - ) -> LocalMedia: - """ - Validate media_id arg and parse the mxc_uri. Then retrieve the media information. 
- - Args: - media_id: The raw media_id arg of request - requester: The user making the request - - Returns: - Return the media info, or None if appropriate - - Raises: - SynapseError: If any of the media is inappropriate or if the requester was not - allowed to attach the media - """ - if not media_id.startswith("mxc://"): - media_id = f"mxc://{media_id}" - mxc_uri = MXCUri.from_str(media_id) - media_info = await self.store.get_local_media(mxc_uri.media_id) - if media_info is None or media_info.user_id != requester.user.to_string(): - raise SynapseError( - HTTPStatus.BAD_REQUEST, - f"The media attachment request is invalid as the media '{mxc_uri.media_id}' does not exist", - Codes.INVALID_PARAM, - ) - if not media_info.restricted: - raise SynapseError( - HTTPStatus.BAD_REQUEST, - f"The media attachment request is invalid as the media '{mxc_uri.media_id}' is not restricted", - Codes.INVALID_PARAM, - ) - return media_info - async def get_remote_media( self, request: SynapseRequest, diff --git a/synapse/rest/client/profile.py b/synapse/rest/client/profile.py index cee8995724..f48c437568 100644 --- a/synapse/rest/client/profile.py +++ b/synapse/rest/client/profile.py @@ -25,6 +25,8 @@ from http import HTTPStatus from typing import TYPE_CHECKING, Tuple +from matrix_common.types.mxc_uri import MXCUri + from synapse.api.constants import ProfileFields from synapse.api.errors import Codes, SynapseError from synapse.handlers.profile import MAX_CUSTOM_FIELD_LEN @@ -36,7 +38,8 @@ ) from synapse.http.site import SynapseRequest from synapse.rest.client._base import client_patterns -from synapse.types import JsonDict, JsonValue, UserID +from synapse.storage.databases.main.media_repository import LocalMedia +from synapse.types import JsonDict, JsonValue, Requester, UserID from synapse.util.stringutils import is_namedspaced_grammar if TYPE_CHECKING: @@ -112,6 +115,41 @@ def __init__(self, hs: "HomeServer"): self.enable_restricted_media = hs.config.experimental.msc3911_enabled 
self.media_repository = hs.get_media_repository() + async def validate_avatar_url_and_retrieve_media_info( + self, avatar_url: str, requester: Requester + ) -> LocalMedia: + """ + Validate avatar_url arg and parse the mxc_uri. Then retrieve the media information. + + Args: + avatar_url: The raw avatar_url arg of request + requester: The user making the request + + Returns: + Return the media info, or None if appropriate + + Raises: + SynapseError: If any of the media is inappropriate or if the requester was not + allowed to attach the media + """ + if not avatar_url.startswith("mxc://"): + avatar_url = f"mxc://{avatar_url}" + mxc_uri = MXCUri.from_str(avatar_url) + media_info = await self.media_repository.store.get_local_media(mxc_uri.media_id) + if media_info is None or media_info.user_id != requester.user.to_string(): + raise SynapseError( + HTTPStatus.BAD_REQUEST, + f"The media attachment request is invalid as the media '{mxc_uri.media_id}' does not exist", + Codes.INVALID_PARAM, + ) + if not media_info.restricted: + raise SynapseError( + HTTPStatus.BAD_REQUEST, + f"The media attachment request is invalid as the media '{mxc_uri.media_id}' is not restricted", + Codes.INVALID_PARAM, + ) + return media_info + async def on_GET( self, request: SynapseRequest, user_id: str, field_name: str ) -> Tuple[int, JsonDict]: @@ -199,15 +237,25 @@ async def on_PUT( "Updating profile while account is suspended is not allowed.", Codes.USER_ACCOUNT_SUSPENDED, ) - if field_name == ProfileFields.DISPLAYNAME: await self.profile_handler.set_displayname( user, requester, new_value, is_admin, propagate=propagate ) elif field_name == ProfileFields.AVATAR_URL: if self.enable_restricted_media and new_value: - validated_media = await self.media_repository.validate_media_url_and_retrieve_media_info( - new_value, requester + current_avatar_url = ( + await self.profile_handler.store.get_profile_avatar_url( + requester.user + ) + ) + if current_avatar_url and new_value == str( + 
MXCUri(self.hs.hostname, current_avatar_url) + ): + return 200, {} + validated_media = ( + await self.validate_avatar_url_and_retrieve_media_info( + new_value, requester + ) ) new_value = validated_media.media_id await self.profile_handler.set_avatar_url( diff --git a/tests/rest/client/test_profile.py b/tests/rest/client/test_profile.py index 7e703775f5..321fd3b1ee 100644 --- a/tests/rest/client/test_profile.py +++ b/tests/rest/client/test_profile.py @@ -1052,13 +1052,37 @@ def test_attaching_already_attached_media_to_profile_fails(self) -> None: assert channel.code == HTTPStatus.OK assert channel.json_body == {} - # Try attaching the same media again. + # Try attaching the same media again. It is idempotent operation so does not return any error. channel = self.make_request( "PUT", f"/_matrix/client/v3/profile/{self.user}/avatar_url", access_token=self.tok, content={"avatar_url": str(mxc_uri)}, ) + assert channel.code == HTTPStatus.OK + assert channel.json_body == {} + + # Try attaching other media that already has restrictions from other users fails. 
+ already_attached = self.create_media_and_set_restricted_flag(self.user) + self.get_success_or_raise( + self.store.set_media_restricted_to_event_id( + self.server_name, already_attached.media_id, "$random_event_id" + ) + ) + retrieved_restrictions = self.get_success_or_raise( + self.store.get_media_restrictions( + self.server_name, already_attached.media_id + ) + ) + assert retrieved_restrictions is not None + assert retrieved_restrictions.event_id == "$random_event_id" + + channel = self.make_request( + "PUT", + f"/_matrix/client/v3/profile/{self.user}/avatar_url", + access_token=self.tok, + content={"avatar_url": str(already_attached)}, + ) assert channel.code == HTTPStatus.BAD_REQUEST, channel.json_body assert channel.json_body["errcode"] == Codes.INVALID_PARAM assert "already has restrictions set" in channel.json_body["error"] From bb9e33efc57cdbf071c62be735f3f7a06f0e1c68 Mon Sep 17 00:00:00 2001 From: Soyoung Kim Date: Thu, 4 Sep 2025 11:14:00 +0200 Subject: [PATCH 10/35] chore: update validation --- synapse/handlers/profile.py | 4 +- synapse/rest/client/profile.py | 99 ++++++++++--------- .../databases/main/media_repository.py | 7 +- tests/rest/client/test_profile.py | 45 ++++++++- 4 files changed, 101 insertions(+), 54 deletions(-) diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py index f22ff9195e..b0730fb6c0 100644 --- a/synapse/handlers/profile.py +++ b/synapse/handlers/profile.py @@ -332,8 +332,6 @@ async def set_avatar_url( target_user, authenticated_entity=requester.authenticated_entity ) - await self.store.set_profile_avatar_url(target_user, avatar_url_to_set) - # msc3911: Update the media restrictions to include the profile user ID if self.enable_restricted_media and avatar_url_to_set: await self.hs.get_datastores().main.set_media_restricted_to_user_profile( @@ -342,6 +340,8 @@ async def set_avatar_url( str(target_user), ) + await self.store.set_profile_avatar_url(target_user, avatar_url_to_set) + profile = await 
self.store.get_profileinfo(target_user) await self.user_directory_handler.handle_local_profile_change( target_user.to_string(), profile diff --git a/synapse/rest/client/profile.py b/synapse/rest/client/profile.py index f48c437568..bfdf83e9ce 100644 --- a/synapse/rest/client/profile.py +++ b/synapse/rest/client/profile.py @@ -38,7 +38,6 @@ ) from synapse.http.site import SynapseRequest from synapse.rest.client._base import client_patterns -from synapse.storage.databases.main.media_repository import LocalMedia from synapse.types import JsonDict, JsonValue, Requester, UserID from synapse.util.stringutils import is_namedspaced_grammar @@ -113,43 +112,11 @@ def __init__(self, hs: "HomeServer"): self.profile_handler = hs.get_profile_handler() self.auth = hs.get_auth() self.enable_restricted_media = hs.config.experimental.msc3911_enabled + self.disable_unrestricted_media = ( + hs.config.experimental.msc3911_unrestricted_media_upload_disabled + ) self.media_repository = hs.get_media_repository() - async def validate_avatar_url_and_retrieve_media_info( - self, avatar_url: str, requester: Requester - ) -> LocalMedia: - """ - Validate avatar_url arg and parse the mxc_uri. Then retrieve the media information. 
- - Args: - avatar_url: The raw avatar_url arg of request - requester: The user making the request - - Returns: - Return the media info, or None if appropriate - - Raises: - SynapseError: If any of the media is inappropriate or if the requester was not - allowed to attach the media - """ - if not avatar_url.startswith("mxc://"): - avatar_url = f"mxc://{avatar_url}" - mxc_uri = MXCUri.from_str(avatar_url) - media_info = await self.media_repository.store.get_local_media(mxc_uri.media_id) - if media_info is None or media_info.user_id != requester.user.to_string(): - raise SynapseError( - HTTPStatus.BAD_REQUEST, - f"The media attachment request is invalid as the media '{mxc_uri.media_id}' does not exist", - Codes.INVALID_PARAM, - ) - if not media_info.restricted: - raise SynapseError( - HTTPStatus.BAD_REQUEST, - f"The media attachment request is invalid as the media '{mxc_uri.media_id}' is not restricted", - Codes.INVALID_PARAM, - ) - return media_info - async def on_GET( self, request: SynapseRequest, user_id: str, field_name: str ) -> Tuple[int, JsonDict]: @@ -188,6 +155,54 @@ async def on_GET( return 200, {field_name: field_value} + async def validate_avatar_url(self, avatar_url: str, requester: Requester) -> None: + """ + Validate avatar_url to make sure the media is owned by the requester or media + is already attached to other event or profile. 
+ + Args: + avatar_url: The raw avatar_url arg of request + requester: The user making the request + + Returns: + Return None when all the validations pass + + Raises: + SynapseError: If any of the media is inappropriate or if the requester was not + allowed to attach the media + """ + if not avatar_url.startswith("mxc://"): + avatar_url = f"mxc://{avatar_url}" + mxc_uri = MXCUri.from_str(avatar_url) + + media_info = await self.media_repository.store.get_local_media(mxc_uri.media_id) + if media_info is None or media_info.user_id != requester.user.to_string(): + raise SynapseError( + HTTPStatus.BAD_REQUEST, + f"The media attachment request is invalid as the media '{mxc_uri.media_id}' does not exist", + Codes.INVALID_PARAM, + ) + if self.disable_unrestricted_media and not media_info.restricted: + raise SynapseError( + HTTPStatus.BAD_REQUEST, + f"The media attachment request is invalid as the media '{mxc_uri.media_id}' is not restricted", + Codes.INVALID_PARAM, + ) + if ( + media_info.restricted + and media_info.attachments + and ( + media_info.attachments.event_id + or media_info.attachments.profile_user_id + ) + ): + raise SynapseError( + HTTPStatus.BAD_REQUEST, + f"The media attachment request is invalid as the media '{mxc_uri.media_id}' is already attached", + Codes.INVALID_PARAM, + ) + return + async def on_PUT( self, request: SynapseRequest, user_id: str, field_name: str ) -> Tuple[int, JsonDict]: @@ -248,16 +263,10 @@ async def on_PUT( requester.user ) ) - if current_avatar_url and new_value == str( - MXCUri(self.hs.hostname, current_avatar_url) - ): + # If new_value is the same as existing one, keep the function idempotent + if current_avatar_url and str(current_avatar_url) == new_value: return 200, {} - validated_media = ( - await self.validate_avatar_url_and_retrieve_media_info( - new_value, requester - ) - ) - new_value = validated_media.media_id + await self.validate_avatar_url(new_value, requester) await self.profile_handler.set_avatar_url( user, 
requester, new_value, is_admin, propagate=propagate ) diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py index e175253d8f..55dcbd6fe7 100644 --- a/synapse/storage/databases/main/media_repository.py +++ b/synapse/storage/databases/main/media_repository.py @@ -1227,13 +1227,16 @@ async def set_media_restricted_to_user_profile( Add the media restrictions to a given profile for a User ID to the database Args: - server_name: - media_id: + server_name: The name of the server + media_id: The media ID that doesn't contain "mxc://{servername}/" profile_user_id: The User ID's profile to restrict the media to Raises: SynapseError if the media already has restrictions on it """ + if "mxc://" in media_id: + # parse mxc_uri into media_id + media_id = media_id.split("/")[-1] await self.db_pool.runInteraction( "set_media_restricted_to_user_profile", self.set_media_restricted_to_user_profile_txn, diff --git a/tests/rest/client/test_profile.py b/tests/rest/client/test_profile.py index 321fd3b1ee..7528c07661 100644 --- a/tests/rest/client/test_profile.py +++ b/tests/rest/client/test_profile.py @@ -42,6 +42,7 @@ from tests import unittest from tests.test_utils import SMALL_PNG +from tests.unittest import override_config from tests.utils import USE_POSTGRES_FOR_TESTS @@ -979,7 +980,6 @@ def test_can_attach_media_to_profile_update(self) -> None: ) assert channel.code == HTTPStatus.OK assert channel.json_body == {} - # Check if the media's restrictions field is updated with the profile_user_id. 
restrictions = self.get_success( self.store.get_media_restrictions(mxc_uri.server_name, mxc_uri.media_id) @@ -1005,9 +1005,9 @@ def test_attaching_nonexistent_media_to_profile_fails(self) -> None: assert channel.json_body["errcode"] == Codes.INVALID_PARAM assert "does not exist" in channel.json_body["error"] - def test_attaching_unrestricted_media_to_profile_fails(self) -> None: + def test_attaching_unrestricted_media_to_profile(self) -> None: """ - Test that attaching unrestricted media to user profile fails. + Test that attaching unrestricted media to user profile also works """ # Create unrestricted media. channel = self.make_request( @@ -1026,6 +1026,41 @@ def test_attaching_unrestricted_media_to_profile_fails(self) -> None: assert media_info is not None assert not media_info.restricted + # Try to update user profile with unrestricted media. + channel = self.make_request( + "PUT", + f"/_matrix/client/v3/profile/{self.user}/avatar_url", + access_token=self.tok, + content={"avatar_url": str(content_uri)}, + ) + assert channel.code == 200, channel.result + + @override_config( + {"experimental_features": {"msc3911_unrestricted_media_upload_disabled": True}} + ) + def test_attaching_unrestricted_media_to_profile_fails(self) -> None: + """ + Test that attaching unrestricted media to user profile fails when unrestircted + media is banned by configuration. + """ + # Create unrestricted media. + content = io.BytesIO(SMALL_PNG) + content_uri = self.get_success( + self.media_repo.create_or_update_content( + "image/png", + "test_png_upload", + content, + 67, + UserID.from_string(self.user), + restricted=False, + ) + ) + + # Check media is unrestricted. + media_info = self.get_success(self.store.get_local_media(content_uri.media_id)) + assert media_info is not None + assert not media_info.restricted + # Try to update user profile with unrestricted media. 
channel = self.make_request( "PUT", @@ -1085,7 +1120,7 @@ def test_attaching_already_attached_media_to_profile_fails(self) -> None: ) assert channel.code == HTTPStatus.BAD_REQUEST, channel.json_body assert channel.json_body["errcode"] == Codes.INVALID_PARAM - assert "already has restrictions set" in channel.json_body["error"] + assert "already attached" in channel.json_body["error"] def test_attaching_not_owned_media_to_profile_fails(self) -> None: """ @@ -1126,7 +1161,7 @@ def test_remove_media_from_profile(self) -> None: ) ) assert user_avatar is not None - assert "mxc://test/" + user_avatar == str(mxc_uri) + assert user_avatar == str(mxc_uri) # Remove the media from the user profile. channel = self.make_request( From 020eb8e0848763a6c076d2a74b66a21e0041a394 Mon Sep 17 00:00:00 2001 From: Soyoung Kim Date: Thu, 4 Sep 2025 11:22:48 +0200 Subject: [PATCH 11/35] chore: validate avatar url in handler --- synapse/handlers/profile.py | 57 ++++++++++++++++++++++++++++++++++ synapse/rest/client/profile.py | 56 +-------------------------------- 2 files changed, 58 insertions(+), 55 deletions(-) diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py index b0730fb6c0..f8ff2dc947 100644 --- a/synapse/handlers/profile.py +++ b/synapse/handlers/profile.py @@ -20,8 +20,11 @@ # import logging import random +from http import HTTPStatus from typing import TYPE_CHECKING, List, Optional, Union +from matrix_common.types.mxc_uri import MXCUri + from synapse.api.constants import ProfileFields from synapse.api.errors import ( AuthError, @@ -78,6 +81,9 @@ def __init__(self, hs: "HomeServer"): self._third_party_rules = hs.get_module_api_callbacks().third_party_event_rules self.enable_restricted_media = hs.config.experimental.msc3911_enabled + self.disable_unrestricted_media = ( + hs.config.experimental.msc3911_unrestricted_media_upload_disabled + ) async def get_profile(self, user_id: str, ignore_backoff: bool = True) -> JsonDict: """ @@ -275,6 +281,56 @@ async def 
get_avatar_url(self, target_user: UserID) -> Optional[str]: return result.get("avatar_url") + async def validate_avatar_url(self, avatar_url: str, requester: Requester) -> None: + """ + Validate avatar_url to make sure the media is owned by the requester or media + is already attached to other event or profile. + + Args: + avatar_url: The raw avatar_url arg of request + requester: The user making the request + + Returns: + Return None when all the validations pass + + Raises: + SynapseError: If any of the media is inappropriate or if the requester was not + allowed to attach the media + """ + if not avatar_url.startswith("mxc://"): + avatar_url = f"mxc://{avatar_url}" + mxc_uri = MXCUri.from_str(avatar_url) + + media_info = await self.hs.get_datastores().main.get_local_media( + mxc_uri.media_id + ) + if media_info is None or media_info.user_id != requester.user.to_string(): + raise SynapseError( + HTTPStatus.BAD_REQUEST, + f"The media attachment request is invalid as the media '{mxc_uri.media_id}' does not exist", + Codes.INVALID_PARAM, + ) + if self.disable_unrestricted_media and not media_info.restricted: + raise SynapseError( + HTTPStatus.BAD_REQUEST, + f"The media attachment request is invalid as the media '{mxc_uri.media_id}' is not restricted", + Codes.INVALID_PARAM, + ) + if ( + media_info.restricted + and media_info.attachments + and ( + media_info.attachments.event_id + or media_info.attachments.profile_user_id + ) + ): + raise SynapseError( + HTTPStatus.BAD_REQUEST, + f"The media attachment request is invalid as the media '{mxc_uri.media_id}' is already attached", + Codes.INVALID_PARAM, + ) + return + async def set_avatar_url( self, target_user: UserID, @@ -334,6 +390,7 @@ async def set_avatar_url( # msc3911: Update the media restrictions to include the profile user ID if self.enable_restricted_media and avatar_url_to_set: + await self.validate_avatar_url(avatar_url_to_set, requester) await 
self.hs.get_datastores().main.set_media_restricted_to_user_profile( self.hs.config.server.server_name, avatar_url_to_set, diff --git a/synapse/rest/client/profile.py b/synapse/rest/client/profile.py index bfdf83e9ce..45c40120ca 100644 --- a/synapse/rest/client/profile.py +++ b/synapse/rest/client/profile.py @@ -25,8 +25,6 @@ from http import HTTPStatus from typing import TYPE_CHECKING, Tuple -from matrix_common.types.mxc_uri import MXCUri - from synapse.api.constants import ProfileFields from synapse.api.errors import Codes, SynapseError from synapse.handlers.profile import MAX_CUSTOM_FIELD_LEN @@ -38,7 +36,7 @@ ) from synapse.http.site import SynapseRequest from synapse.rest.client._base import client_patterns -from synapse.types import JsonDict, JsonValue, Requester, UserID +from synapse.types import JsonDict, JsonValue, UserID from synapse.util.stringutils import is_namedspaced_grammar if TYPE_CHECKING: @@ -112,9 +110,6 @@ def __init__(self, hs: "HomeServer"): self.profile_handler = hs.get_profile_handler() self.auth = hs.get_auth() self.enable_restricted_media = hs.config.experimental.msc3911_enabled - self.disable_unrestricted_media = ( - hs.config.experimental.msc3911_unrestricted_media_upload_disabled - ) self.media_repository = hs.get_media_repository() async def on_GET( @@ -155,54 +150,6 @@ async def on_GET( return 200, {field_name: field_value} - async def validate_avatar_url(self, avatar_url: str, requester: Requester) -> None: - """ - Validate avatar_url to make sure the media is owned by the requester or media - is already attached to other event or profile. 
- - Args: - avatar_url: The raw avatar_url arg of request - requester: The user making the request - - Returns: - Return None when all the validations pass - - Raises: - SynapseError: If any of the media is inappropriate or if the requester was not - allowed to attach the media - """ - if not avatar_url.startswith("mxc://"): - avatar_url = f"mxc://{avatar_url}" - mxc_uri = MXCUri.from_str(avatar_url) - - media_info = await self.media_repository.store.get_local_media(mxc_uri.media_id) - if media_info is None or media_info.user_id != requester.user.to_string(): - raise SynapseError( - HTTPStatus.BAD_REQUEST, - f"The media attachment request is invalid as the media '{mxc_uri.media_id}' does not exist", - Codes.INVALID_PARAM, - ) - if self.disable_unrestricted_media and not media_info.restricted: - raise SynapseError( - HTTPStatus.BAD_REQUEST, - f"The media attachment request is invalid as the media '{mxc_uri.media_id}' is not restricted", - Codes.INVALID_PARAM, - ) - if ( - media_info.restricted - and media_info.attachments - and ( - media_info.attachments.event_id - or media_info.attachments.profile_user_id - ) - ): - raise SynapseError( - HTTPStatus.BAD_REQUEST, - f"The media attachment request is invalid as the media '{mxc_uri.media_id}' is already attached", - Codes.INVALID_PARAM, - ) - return - async def on_PUT( self, request: SynapseRequest, user_id: str, field_name: str ) -> Tuple[int, JsonDict]: @@ -266,7 +213,6 @@ async def on_PUT( # If new_value is the same as existing one, keep the function idempotent if current_avatar_url and str(current_avatar_url) == new_value: return 200, {} - await self.validate_avatar_url(new_value, requester) await self.profile_handler.set_avatar_url( user, requester, new_value, is_admin, propagate=propagate ) From 64483d6a7b033f51078deb2e3a91b2f8ed7cb89a Mon Sep 17 00:00:00 2001 From: Soyoung Kim Date: Thu, 4 Sep 2025 13:05:47 +0200 Subject: [PATCH 12/35] fix: remove unused repository --- synapse/rest/client/profile.py | 1 - 1 file 
changed, 1 deletion(-) diff --git a/synapse/rest/client/profile.py b/synapse/rest/client/profile.py index 45c40120ca..0a33cd305a 100644 --- a/synapse/rest/client/profile.py +++ b/synapse/rest/client/profile.py @@ -110,7 +110,6 @@ def __init__(self, hs: "HomeServer"): self.profile_handler = hs.get_profile_handler() self.auth = hs.get_auth() self.enable_restricted_media = hs.config.experimental.msc3911_enabled - self.media_repository = hs.get_media_repository() async def on_GET( self, request: SynapseRequest, user_id: str, field_name: str From 510a4d5e62ca1bee605316bfc6581747d368b34d Mon Sep 17 00:00:00 2001 From: Soyoung Kim Date: Fri, 29 Aug 2025 18:11:44 +0200 Subject: [PATCH 13/35] feat: add copy api --- synapse/rest/client/media.py | 129 +++++++++++++++++++++++++++++++- tests/rest/client/test_media.py | 83 +++++++++++++++++++- 2 files changed, 209 insertions(+), 3 deletions(-) diff --git a/synapse/rest/client/media.py b/synapse/rest/client/media.py index 5ff0178416..50242cee19 100644 --- a/synapse/rest/client/media.py +++ b/synapse/rest/client/media.py @@ -22,8 +22,9 @@ import logging import re -from typing import Optional +from typing import Optional, Union +from synapse.api.errors import Codes, SynapseError from synapse.http.server import ( HttpServer, respond_with_json, @@ -31,7 +32,12 @@ set_corp_headers, set_cors_headers, ) -from synapse.http.servlet import RestServlet, parse_integer, parse_string +from synapse.http.servlet import ( + RestServlet, + parse_integer, + parse_json_object_from_request, + parse_string, +) from synapse.http.site import SynapseRequest from synapse.media._base import ( DEFAULT_MAX_TIMEOUT_MS, @@ -44,6 +50,8 @@ from synapse.rest.media.create_resource import CreateResource from synapse.rest.media.upload_resource import UploadRestrictedResource from synapse.server import HomeServer +from synapse.storage.databases.main.media_repository import LocalMedia, RemoteMedia +from synapse.types import Requester from synapse.util.stringutils 
import parse_and_validate_server_name logger = logging.getLogger(__name__) @@ -277,6 +285,122 @@ async def on_GET( ) +class CopyResource(RestServlet): + """ + MSC3911: This is an unstable endpoint that is introduced in msc3911 scope. This + "copy" api is to be used by clients when forwarding events with media attachments. + Rather than just allowing clients to attach media to multiple events, this ensures + that the list of events attached to a media does not grow over time, so that servers + can reliably cache media and impose the correct access restrictions. + """ + + # Stable: /_matrix/client/v1/media/copy/{serverName}/{mediaId} + PATTERNS = [ + re.compile( + "/_matrix/client/unstable/org.matrix.msc3911/media/copy/(?P[^/]*)/(?P[^/]*)" + ) + ] + + def __init__(self, hs: "HomeServer", media_repo: "MediaRepository"): + super().__init__() + self.store = hs.get_datastores().main + self.media_repo = media_repo + self.auth = hs.get_auth() + self._is_mine_server_name = hs.is_mine_server_name + self.limits_dict = {"m.upload.size": hs.config.media.max_upload_size} + self.media_repository_callbacks = hs.get_module_api_callbacks().media_repository + self.clock = hs.get_clock() + + async def _validate_user_media_limit( + self, requester: Requester, media_info: Union[LocalMedia, RemoteMedia, None] + ) -> None: + """Check if the request exceeds the user's media limits.""" + media_config = await self.media_repository_callbacks.get_media_config_for_user( + requester.user.to_string(), + ) + if not media_config: + media_config = self.limits_dict + + max_upload_size = media_config.get("m.upload.size") + if max_upload_size and media_info and media_info.media_length: + # QUESTION: do we need to also take the already uploaded data amount into account? 
+ if media_info.media_length > max_upload_size: + raise SynapseError(400, Codes.RESOURCE_LIMIT_EXCEEDED) + + async def on_POST( + self, + request: SynapseRequest, + server_name: str, + media_id: str, + ) -> None: + """ + Handles copying a media item referenced by server_name and media_id. + Returns a new MXC URI for the copied media. + """ + requester = await self.auth.get_user_by_req(request) + + # Optionally parse request body (must be a JSON object, but no required params) + # QUESTION: Not sure what information the content is carrying. events info? + content = parse_json_object_from_request(request, allow_empty_body=True) # noqa F841 + + # Check if media exists and get media info + local_media = False + media_info: Union[LocalMedia, RemoteMedia, None] = None + if self._is_mine_server_name(server_name): + local_media = True + media_info = await self.store.get_local_media(media_id) + else: + media_info = await self.media_repo.get_remote_media_info( + server_name, + media_id, + MAXIMUM_ALLOWED_MAX_TIMEOUT_MS, + request.getClientAddress().host, + False, # Not sure this is correct value for use_federation + True, # Not sure this is correct value for allow_authenticated + ) + await self._validate_user_media_limit(requester, media_info) + + # Creates new copy of media item. (New reference to an existing item) + # Storage might be shared in the future (by storing via a content hash) + if media_info and local_media: + try: + # QUESTION: remote media copies are also stored in local_media_repository? 
+ mxc_uri, _ = await self.media_repo.create_media_id( + requester.user, restricted=True + ) + if media_info.media_length and media_info.sha256: + await self.store.update_local_media( + media_id=mxc_uri.split("/")[-1], + media_type=media_info.media_type, + upload_name=media_info.upload_name, + media_length=media_info.media_length, + user_id=requester.user, + sha256=media_info.sha256, + quarantined_by=None, # QUESTION: Not sure how quarantine media works + ) + + # Jetzt media is in pending state. + + # When copying media it should be in the unattached state until the user manually attaches it to a new event. + # New Media reference can be attached to a new event. (like uploading a new media) + # If attach succeed, response with json object with a required content_uri, giving a new MXC URI referring to the media. + + # TODO: attach to the event logic. where does event info coming from? body param? + + # Respond with the new MXC URI + respond_with_json( + request, + 200, + {"content_uri": mxc_uri}, + send_cors=True, + ) + + except Exception as e: + logger.error("Failed to copy media: %s", e) + respond_with_json(request, 500, {"error": "Failed to copy media"}) + return + + def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: media_repo = hs.get_media_repository() if hs.config.media.url_preview_enabled: @@ -289,3 +413,4 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: if hs.config.experimental.msc3911_enabled: CreateResource(hs, media_repo, restricted=True).register(http_server) UploadRestrictedResource(hs, media_repo).register(http_server) + CopyResource(hs, media_repo).register(http_server) diff --git a/tests/rest/client/test_media.py b/tests/rest/client/test_media.py index ba38b56041..e57d46ddac 100644 --- a/tests/rest/client/test_media.py +++ b/tests/rest/client/test_media.py @@ -3029,7 +3029,7 @@ def test_unrestricted_resource_upload_disabled(self) -> None: ) -class 
RestrictedResourceTestCase(unittest.HomeserverTestCase): +class RestrictedResourceUploadTestCase(unittest.HomeserverTestCase): """ Tests restricted media creation and upload endpoints when `msc3911_enabled` is configured to be True. @@ -3176,3 +3176,84 @@ def test_async_upload_restricted_resource(self) -> None: access_token=self.other_user_tok, ) assert channel.code == 404 + + +class CopyRestrictedResource(unittest.HomeserverTestCase): + """ + Tests copy API when `msc3911_enabled` is configured to be True. + """ + + extra_config = { + "experimental_features": {"msc3911_enabled": True}, + } + + servlets = [ + media.register_servlets, + login.register_servlets, + admin.register_servlets, + ] + + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: + config = self.default_config() + config.update(self.extra_config) + return self.setup_test_homeserver(config=config) + + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + self.media_repo = hs.get_media_repository_resource() + self.user = self.register_user("user", "testpass") + self.user_tok = self.login("user", "testpass") + + self.other_user = self.register_user("other", "testpass") + self.other_user_tok = self.login("other", "testpass") + + def create_resource_dict(self) -> dict[str, Resource]: + resources = super().create_resource_dict() + resources["/_matrix/media"] = self.hs.get_media_repository_resource() + return resources + + def _create_restricted_resource(self) -> str: + channel = self.make_request( + "POST", + "/_matrix/client/unstable/org.matrix.msc3911/media/upload?filename=test_png_upload", + content=SMALL_PNG, + content_type=b"image/png", + access_token=self.user_tok, + custom_headers=[("Content-Length", str(67))], + ) + self.assertEqual(channel.code, 200) + self.assertIn("content_uri", channel.json_body) + return channel.json_body["content_uri"].split("/")[-1] + + def test_copy_restricted_resource(self) -> None: + """ + Tests that the new copy 
endpoint creates a new mxc uri for restricted resource. + """ + # media is created with user_tok + media_id = self._create_restricted_resource() + # copy request from other_user + channel = self.make_request( + "POST", + f"/_matrix/client/unstable/org.matrix.msc3911/media/copy/{self.hs.hostname}/{media_id}", + access_token=self.other_user_tok, + ) + self.assertEqual(channel.code, 200) + self.assertIn("content_uri", channel.json_body) + new_media_id = channel.json_body["content_uri"].split("/")[-1] + assert new_media_id != media_id + + # Check if the original media there. + original_media = self.get_success( + self.hs.get_datastores().main.get_local_media(media_id) + ) + assert original_media is not None + assert original_media.user_id == self.user + + # Check the copied media. + copied_media = self.get_success( + self.hs.get_datastores().main.get_local_media(new_media_id) + ) + assert copied_media is not None + assert copied_media.user_id == self.other_user + + # Check if they are referencing the same image. 
+ assert original_media.sha256 == copied_media.sha256 From 65bfb2b39db77f8a348050084fb804d02ccbde05 Mon Sep 17 00:00:00 2001 From: Soyoung Kim Date: Thu, 4 Sep 2025 16:22:31 +0200 Subject: [PATCH 14/35] chore: add more test --- synapse/rest/client/media.py | 52 +++++++++--------- tests/rest/client/test_media.py | 94 +++++++++++++++++++++++++++------ 2 files changed, 102 insertions(+), 44 deletions(-) diff --git a/synapse/rest/client/media.py b/synapse/rest/client/media.py index 50242cee19..41edfdb33d 100644 --- a/synapse/rest/client/media.py +++ b/synapse/rest/client/media.py @@ -24,7 +24,11 @@ import re from typing import Optional, Union -from synapse.api.errors import Codes, SynapseError +from synapse.api.errors import ( + Codes, + NotFoundError, + SynapseError, +) from synapse.http.server import ( HttpServer, respond_with_json, @@ -323,7 +327,7 @@ async def _validate_user_media_limit( max_upload_size = media_config.get("m.upload.size") if max_upload_size and media_info and media_info.media_length: - # QUESTION: do we need to also take the already uploaded data amount into account? + # We are not counting the amount of media the user uploaded in a previous time period if media_info.media_length > max_upload_size: raise SynapseError(400, Codes.RESOURCE_LIMIT_EXCEEDED) @@ -337,34 +341,38 @@ async def on_POST( Handles copying a media item referenced by server_name and media_id. Returns a new MXC URI for the copied media. """ + max_timeout_ms = parse_integer( + request, "timeout_ms", default=DEFAULT_MAX_TIMEOUT_MS + ) + max_timeout_ms = min(max_timeout_ms, MAXIMUM_ALLOWED_MAX_TIMEOUT_MS) + requester = await self.auth.get_user_by_req(request) - # Optionally parse request body (must be a JSON object, but no required params) - # QUESTION: Not sure what information the content is carrying. events info? - content = parse_json_object_from_request(request, allow_empty_body=True) # noqa F841 + # Optionally parse request body, must be a JSON object, but no required params. 
+ _ = parse_json_object_from_request(request, allow_empty_body=True) - # Check if media exists and get media info - local_media = False media_info: Union[LocalMedia, RemoteMedia, None] = None if self._is_mine_server_name(server_name): - local_media = True media_info = await self.store.get_local_media(media_id) else: media_info = await self.media_repo.get_remote_media_info( server_name, media_id, - MAXIMUM_ALLOWED_MAX_TIMEOUT_MS, + max_timeout_ms, request.getClientAddress().host, - False, # Not sure this is correct value for use_federation - True, # Not sure this is correct value for allow_authenticated + use_federation=True, # Not sure this is correct value for use_federation + allow_authenticated=True, ) + + if not media_info: + raise NotFoundError() + if media_info.quarantined_by: + raise NotFoundError() + await self._validate_user_media_limit(requester, media_info) - # Creates new copy of media item. (New reference to an existing item) - # Storage might be shared in the future (by storing via a content hash) - if media_info and local_media: + if media_info: try: - # QUESTION: remote media copies are also stored in local_media_repository? mxc_uri, _ = await self.media_repo.create_media_id( requester.user, restricted=True ) @@ -376,29 +384,17 @@ async def on_POST( media_length=media_info.media_length, user_id=requester.user, sha256=media_info.sha256, - quarantined_by=None, # QUESTION: Not sure how quarantine media works + quarantined_by=None, ) - - # Jetzt media is in pending state. - - # When copying media it should be in the unattached state until the user manually attaches it to a new event. - # New Media reference can be attached to a new event. (like uploading a new media) - # If attach succeed, response with json object with a required content_uri, giving a new MXC URI referring to the media. - - # TODO: attach to the event logic. where does event info coming from? body param? 
- - # Respond with the new MXC URI respond_with_json( request, 200, {"content_uri": mxc_uri}, send_cors=True, ) - except Exception as e: logger.error("Failed to copy media: %s", e) respond_with_json(request, 500, {"error": "Failed to copy media"}) - return def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: diff --git a/tests/rest/client/test_media.py b/tests/rest/client/test_media.py index e57d46ddac..bffbd982ef 100644 --- a/tests/rest/client/test_media.py +++ b/tests/rest/client/test_media.py @@ -3199,10 +3199,9 @@ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: return self.setup_test_homeserver(config=config) def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - self.media_repo = hs.get_media_repository_resource() + self.media_repo = hs.get_media_repository() self.user = self.register_user("user", "testpass") self.user_tok = self.login("user", "testpass") - self.other_user = self.register_user("other", "testpass") self.other_user_tok = self.login("other", "testpass") @@ -3211,29 +3210,87 @@ def create_resource_dict(self) -> dict[str, Resource]: resources["/_matrix/media"] = self.hs.get_media_repository_resource() return resources - def _create_restricted_resource(self) -> str: + def test_copy_local_restricted_resource(self) -> None: + """ + Tests that the new copy endpoint creates a new mxc uri for restricted resource. 
+ """ + # The media is created with user_tok + content = io.BytesIO(SMALL_PNG) + content_uri = self.get_success( + self.media_repo.create_or_update_content( + "image/png", + "test_png_upload", + content, + 67, + UserID.from_string(self.user), + restricted=True, + ) + ) + media_id = content_uri.media_id + + # The other_user copies the media from local server channel = self.make_request( "POST", - "/_matrix/client/unstable/org.matrix.msc3911/media/upload?filename=test_png_upload", - content=SMALL_PNG, - content_type=b"image/png", - access_token=self.user_tok, - custom_headers=[("Content-Length", str(67))], + f"/_matrix/client/unstable/org.matrix.msc3911/media/copy/{self.hs.hostname}/{media_id}", + access_token=self.other_user_tok, ) self.assertEqual(channel.code, 200) self.assertIn("content_uri", channel.json_body) - return channel.json_body["content_uri"].split("/")[-1] + new_media_id = channel.json_body["content_uri"].split("/")[-1] + assert new_media_id != media_id + + # Check if the original media there. + original_media = self.get_success( + self.hs.get_datastores().main.get_local_media(media_id) + ) + assert original_media is not None + assert original_media.user_id == self.user - def test_copy_restricted_resource(self) -> None: + # Check the copied media. + copied_media = self.get_success( + self.hs.get_datastores().main.get_local_media(new_media_id) + ) + assert copied_media is not None + assert copied_media.user_id == self.other_user + + # Check if they are referencing the same image. + assert original_media.sha256 == copied_media.sha256 + + # Check if media is unattached to any event or user profile yet. + assert copied_media.attachments is None + + def test_copy_remote_restricted_resource(self) -> None: """ Tests that the new copy endpoint creates a new mxc uri for restricted resource. 
""" - # media is created with user_tok - media_id = self._create_restricted_resource() - # copy request from other_user + # create remote media + remote_server = "remoteserver.com" + remote_file_id = "remote1" + file_info = FileInfo(server_name=remote_server, file_id=remote_file_id) + + media_storage = self.hs.get_media_repository().media_storage + ctx = media_storage.store_into_file(file_info) + (f, _) = self.get_success(ctx.__aenter__()) + f.write(SMALL_PNG) + self.get_success(ctx.__aexit__(None, None, None)) + media_id = "remotemedia" + self.get_success( + self.hs.get_datastores().main.store_cached_remote_media( + origin=remote_server, + media_id=media_id, + media_type="image/png", + media_length=1, + time_now_ms=self.clock.time_msec(), + upload_name="test.png", + filesystem_id=remote_file_id, + sha256=remote_file_id, + ) + ) + + # The other_user copies the media from remote server channel = self.make_request( "POST", - f"/_matrix/client/unstable/org.matrix.msc3911/media/copy/{self.hs.hostname}/{media_id}", + f"/_matrix/client/unstable/org.matrix.msc3911/media/copy/{remote_server}/{media_id}", access_token=self.other_user_tok, ) self.assertEqual(channel.code, 200) @@ -3243,10 +3300,12 @@ def test_copy_restricted_resource(self) -> None: # Check if the original media there. original_media = self.get_success( - self.hs.get_datastores().main.get_local_media(media_id) + self.hs.get_datastores().main.get_cached_remote_media( + remote_server, media_id + ) ) assert original_media is not None - assert original_media.user_id == self.user + assert original_media.upload_name == "test.png" # Check the copied media. copied_media = self.get_success( @@ -3257,3 +3316,6 @@ def test_copy_restricted_resource(self) -> None: # Check if they are referencing the same image. assert original_media.sha256 == copied_media.sha256 + + # Check if copied media is unattached to any event or user profile yet. 
+ assert copied_media.attachments is None From c9aee372dbc7b856d00226ca7d08ca5f2573a675 Mon Sep 17 00:00:00 2001 From: Jason Little Date: Tue, 2 Sep 2025 12:08:12 -0500 Subject: [PATCH 15/35] feat: msc3911[AP5] - Update room creation handler to recognize and attach restricted m.room.avatar media --- synapse/handlers/room.py | 54 +++++++++++ tests/rest/client/test_rooms.py | 165 +++++++++++++++++++++++++++++++- 2 files changed, 215 insertions(+), 4 deletions(-) diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 326f04a2ff..7451feca16 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -41,6 +41,7 @@ ) import attr +from matrix_common.types.mxc_uri import MXCUri import synapse.events.snapshot from synapse.api.constants import ( @@ -1498,6 +1499,58 @@ async def create_event( event_dict.update(event_keys) event_dict.update(kwargs) + mxc_restrictions = None + if ( + self.config.experimental.msc3911_enabled + and etype == EventTypes.RoomAvatar + ): + # this should be an mxc, but the spec does not specifically say it has to be + extracted_media_id: Optional[str] = content.get("url") + # It may be that "url" is set to either an empty string or None. Accept + # this gracefully, to account for backwards compatible behavior + if extracted_media_id: + try: + mxc_uri = MXCUri.from_str(extracted_media_id) + except ValueError: + raise SynapseError( + HTTPStatus.BAD_REQUEST, + f"Room avatar MXC Uri ('{extracted_media_id}') is malformed", + Codes.INVALID_PARAM, + ) + # If there is a media item, check for existing restrictions + local_media_data = await self.store.get_local_media( + mxc_uri.media_id + ) + + # Non-existent media is to be handled by the download media endpoint + # so ignore it for now and allow proceeding. 
It will just not attach + # the media + if local_media_data is not None and local_media_data.restricted: + if not local_media_data.attachments: + if creator.user.to_string() != local_media_data.user_id: + # A different user created a room compared to who + # uploaded the media. Just like with the '/state/' and + # '/send/ endpoints, do not leak the metadata + raise SynapseError( + HTTPStatus.BAD_REQUEST, + f"The media requested for a room avatar is invalid as the media '{mxc_uri.media_id}' does not exist", + Codes.INVALID_PARAM, + ) + # The media is not already attached to anything, proceed + mxc_restrictions = [ + MXCUri(self.server_name, local_media_data.media_id) + ] + else: + # This media is already attached. If it was a prior attempt + # to create a room, the atomic handling of the room creation + # means it will not be attached, so if this exists it + # succeeded somewhere else. + raise SynapseError( + HTTPStatus.BAD_REQUEST, + f"The media requested for a room avatar is invalid as the media '{mxc_uri.media_id}' does not exist", + Codes.INVALID_PARAM, + ) + ( new_event, new_unpersisted_context, @@ -1510,6 +1563,7 @@ async def create_event( # state_map since it is modified below. 
state_map=dict(state_map), for_batch=for_batch, + mxc_restriction_list_for_event=mxc_restrictions, ) depth += 1 diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py index d7255dee25..b571453c0d 100644 --- a/tests/rest/client/test_rooms.py +++ b/tests/rest/client/test_rooms.py @@ -4525,8 +4525,8 @@ def prepare( self.user = self.register_user("david", "password") self.tok = self.login("david", "password") - self.other_user = self.register_user("mongo", "password") - self.other_tok = self.login("mongo", "password") + self.other_user = self.register_user("elliott", "password") + self.other_tok = self.login("elliott", "password") def default_config(self) -> JsonDict: config = super().default_config() @@ -4833,8 +4833,8 @@ def prepare( self.user = self.register_user("david", "password") self.tok = self.login("david", "password") - self.other_user = self.register_user("mongo", "password") - self.other_tok = self.login("mongo", "password") + self.other_user = self.register_user("elliott", "password") + self.other_tok = self.login("elliott", "password") def default_config(self) -> JsonDict: config = super().default_config() @@ -5108,3 +5108,160 @@ def test_idempotency_of_attaching_media_to_message_event(self) -> None: # Sort if need to do annotations and reactions and other m.relates_to stuff here + + +class RoomsCreateMediaAttachmentTestCase(unittest.HomeserverTestCase): + servlets = [ + admin.register_servlets, + login.register_servlets, + media.register_servlets, + room.register_servlets, + room.register_deprecated_servlets, + ] + + def prepare( + self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer + ) -> None: + self.store = homeserver.get_datastores().main + self.server_name = self.hs.config.server.server_name + self.media_repo = self.hs.get_media_repository() + + self.user = self.register_user("david", "password") + self.tok = self.login("david", "password") + + self.other_user = self.register_user("elliott", "password") + 
self.other_tok = self.login("elliott", "password") + + def default_config(self) -> JsonDict: + config = super().default_config() + config.setdefault("experimental_features", {}) + config["experimental_features"].update({"msc3911_enabled": True}) + return config + + def create_media_and_set_restricted_flag( + self, user_id: Optional[str] = None + ) -> MXCUri: + """ + Create media without using an endpoint, and set the restricted flag. This will + not add restrictions on its own, as that is the point of this test series + """ + # Allow for testing a different user doing the creation, so can test it errors + # when attaching + if user_id is None: + user_id = self.user + content = io.BytesIO(SMALL_PNG) + content_uri = self.get_success( + self.media_repo.create_or_update_content( + "image/png", + "test_png_upload", + content, + 67, + UserID.from_string(user_id), + restricted=True, + ) + ) + return content_uri + + def create_room_with_avatar( + self, + avatar_mxc: Optional[Union[MXCUri, str]] = None, + creating_user: Optional[str] = None, + tok: Optional[str] = None, + expected_code: int = HTTPStatus.OK, + ) -> Optional[str]: + """Create a room with the given avatar""" + initial_state_content = [] + if avatar_mxc == "": + # Simulate that an empty string is passed + initial_state_content.append( + { + "content": { + "url": str(avatar_mxc), + }, + "state_key": "", + "type": EventTypes.RoomAvatar, + } + ) + + elif avatar_mxc is not None: + initial_state_content.append( + { + "content": { + "info": { + "h": 1, + "w": 1, + "mimetype": "image/png", + "size": 67, + }, + "url": str(avatar_mxc), + }, + "state_key": "", + "type": EventTypes.RoomAvatar, + } + ) + return self.helper.create_room_as( + creating_user or self.user, + extra_content={"initial_state": initial_state_content}, + tok=tok or self.tok, + expect_code=expected_code, + ) + + def test_create_room_can_attach_media(self) -> None: + """ + Basic functionality test that restricted media being enabled sets a provide + 
room avatar as restricted + """ + test_mxc = self.create_media_and_set_restricted_flag() + room_id = self.create_room_with_avatar(avatar_mxc=test_mxc) + assert room_id is not None + + def test_create_room_does_not_error_when_no_avatar(self) -> None: + """ + Test that creating a room with no room avatar does not break when msc3911 is + enabled. Basically making sure while the config is enabled that it doesn't break + """ + room_id = self.create_room_with_avatar(avatar_mxc=None) + assert room_id is not None + + def test_create_room_does_not_error_avatar_initially_set_empty(self) -> None: + """ + Test that creating a room with a room avatar that contains an empty string for + the 'url' gracefully is ignored + """ + room_id = self.create_room_with_avatar(avatar_mxc="") + assert room_id is not None + + def test_create_room_fails_with_invalid_request_wrong_user(self) -> None: + """Test that a room creator must have uploaded the media""" + test_mxc = self.create_media_and_set_restricted_flag(self.other_user) + room_id = self.create_room_with_avatar(avatar_mxc=test_mxc, expected_code=400) + assert room_id is None + + def test_create_room_fails_with_already_attached_media(self) -> None: + """ + Test that a creating a room with an avatar that is already attached somewhere + else fails + """ + test_mxc = self.create_media_and_set_restricted_flag() + self.get_success( + self.store.set_media_restricted_to_event_id( + test_mxc.server_name, test_mxc.media_id, "$junk_event_id" + ) + ) + room_id = self.create_room_with_avatar(avatar_mxc=test_mxc, expected_code=400) + assert room_id is None + + def test_create_room_with_unknown_media_avatar_succeeds(self) -> None: + """Test that room creation with unknown media as room avatar doesn't fail""" + test_mxc = MXCUri.from_str(f"mxc://somewhere.else/{time.time_ns()}") + room_id = self.create_room_with_avatar( + creating_user=self.other_user, + avatar_mxc=test_mxc, + tok=self.other_tok, + ) + assert room_id is not None + + def 
test_create_room_fails_with_malformed_room_avatar_url(self) -> None: + """Test that a malformed room avatar url fails the room creation""" + room_id = self.create_room_with_avatar(avatar_mxc="junk", expected_code=400) + assert room_id is None From 6196a5e2899cce2fc35dc925ac5805c5ba2ca90b Mon Sep 17 00:00:00 2001 From: Soyoung Kim Date: Fri, 5 Sep 2025 14:08:08 +0200 Subject: [PATCH 16/35] chore: use default_max_timeout_ms --- synapse/rest/client/media.py | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/synapse/rest/client/media.py b/synapse/rest/client/media.py index 41edfdb33d..c620d9f2a5 100644 --- a/synapse/rest/client/media.py +++ b/synapse/rest/client/media.py @@ -341,11 +341,6 @@ async def on_POST( Handles copying a media item referenced by server_name and media_id. Returns a new MXC URI for the copied media. """ - max_timeout_ms = parse_integer( - request, "timeout_ms", default=DEFAULT_MAX_TIMEOUT_MS - ) - max_timeout_ms = min(max_timeout_ms, MAXIMUM_ALLOWED_MAX_TIMEOUT_MS) - requester = await self.auth.get_user_by_req(request) # Optionally parse request body, must be a JSON object, but no required params. 
@@ -358,9 +353,9 @@ async def on_POST( media_info = await self.media_repo.get_remote_media_info( server_name, media_id, - max_timeout_ms, + DEFAULT_MAX_TIMEOUT_MS, request.getClientAddress().host, - use_federation=True, # Not sure this is correct value for use_federation + use_federation=True, allow_authenticated=True, ) From bc93a341f4398576c3107931c544f54248a7bbae Mon Sep 17 00:00:00 2001 From: Soyoung Kim Date: Mon, 8 Sep 2025 16:40:43 +0200 Subject: [PATCH 17/35] feat: ap8 expose restrictions over federation --- .../federation/transport/server/federation.py | 1 + synapse/media/_base.py | 5 +- synapse/media/media_repository.py | 98 ++++++- synapse/media/thumbnailer.py | 22 ++ .../databases/main/media_repository.py | 7 + tests/federation/test_federation_media.py | 250 +++++++++++++++++- 6 files changed, 366 insertions(+), 17 deletions(-) diff --git a/synapse/federation/transport/server/federation.py b/synapse/federation/transport/server/federation.py index eb96ff27f9..3ebe2756ea 100644 --- a/synapse/federation/transport/server/federation.py +++ b/synapse/federation/transport/server/federation.py @@ -78,6 +78,7 @@ def __init__( ): super().__init__(hs, authenticator, ratelimiter, server_name) self.handler = hs.get_federation_server() + self.enable_restricted_media = hs.config.experimental.msc3911_enabled class FederationSendServlet(BaseFederationServerServlet): diff --git a/synapse/media/_base.py b/synapse/media/_base.py index 29911dab77..b9112ec756 100644 --- a/synapse/media/_base.py +++ b/synapse/media/_base.py @@ -309,6 +309,7 @@ async def respond_with_multipart_responder( media_type: str, media_length: Optional[int], upload_name: Optional[str], + json_response: Optional[dict] = None, ) -> None: """ Responds to requests originating from the federation media `/download` endpoint by @@ -362,7 +363,9 @@ def _quote(x: str) -> str: clock, request, media_type, - {}, # Note: if we change this we need to change the returned ETag. 
+ json_response + if json_response + else {}, # Note: if we change this we need to change the returned ETag. disposition, media_length, ) diff --git a/synapse/media/media_repository.py b/synapse/media/media_repository.py index c2e04ea767..f3a616964f 100644 --- a/synapse/media/media_repository.py +++ b/synapse/media/media_repository.py @@ -23,6 +23,7 @@ import logging import os import shutil +from http import HTTPStatus from io import BytesIO from typing import IO, TYPE_CHECKING, Dict, List, Optional, Set, Tuple @@ -67,7 +68,11 @@ from synapse.media.thumbnailer import Thumbnailer, ThumbnailError from synapse.media.url_previewer import UrlPreviewer from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.storage.databases.main.media_repository import LocalMedia, RemoteMedia +from synapse.storage.databases.main.media_repository import ( + LocalMedia, + MediaRestrictions, + RemoteMedia, +) from synapse.types import Requester, UserID from synapse.util.async_helpers import Linearizer from synapse.util.retryutils import NotRetryingDestination @@ -182,6 +187,7 @@ def __init__(self, hs: "HomeServer"): self.media_upload_limits.sort( key=lambda limit: limit.time_period_ms, reverse=True ) + self.msc3911_enabled = hs.config.experimental.msc3911_enabled def _start_update_recently_accessed(self) -> Deferred: return run_as_background_process( @@ -493,19 +499,12 @@ async def get_local_media( if media_info.authenticated: raise NotFoundError() - # MSC3911: If media is restricted but restriction is empty, the media is in - # pending state and only creator can see it until it is attached to an event. 
- if media_info.restricted: - restrictions = await self.store.get_media_restrictions( - self.server_name, media_info.media_id + restrictions = None + if self.msc3911_enabled: + restrictions = await self.validate_media_restriction( + request, media_info, None, federation ) - if not restrictions: - if not ( - isinstance(request.requester, Requester) - and request.requester.user.to_string() == media_info.user_id - ): - respond_404(request) - return + restrictions_json = restrictions.to_dict() if restrictions else {} self.mark_recently_accessed(None, media_id) @@ -526,7 +525,13 @@ async def get_local_media( responder = await self.media_storage.fetch_media(file_info) if federation: await respond_with_multipart_responder( - self.clock, request, responder, media_type, media_length, upload_name + self.clock, + request, + responder, + media_type, + media_length, + upload_name, + restrictions_json, ) else: await respond_with_responder( @@ -1568,3 +1573,68 @@ async def _remove_local_media_from_disk( removed_media.append(media_id) return removed_media, len(removed_media) + + async def validate_media_restriction( + self, + request: SynapseRequest, + media_info: Optional[LocalMedia], + media_id: Optional[str], + is_federation: bool = False, + ) -> Optional[MediaRestrictions]: + """ + MSC3911: If media is restricted but restriction is empty, the media is in + pending state and only creator can see it until it is attached to an event. If + there is a restriction return MediaRestrictions after validation. + + Args: + request: The incoming request. + media_info: Optional, the media information. + media_id: Optional, the media ID to validate. + + Returns: + MediaRestrictions if there is one set, otherwise raise SynapseError. 
+ """ + if not media_info and media_id: + media_info = await self.store.get_local_media(media_id) + if not media_info: + return None + restricted = media_info.restricted + if not restricted: + return None + attachments: Optional[MediaRestrictions] = media_info.attachments + # for both federation and client endpoints + if attachments: + # Only one of event_id or profile_user_id must be set, not both, not neither + if attachments.event_id is None and attachments.profile_user_id is None: + raise SynapseError( + HTTPStatus.FORBIDDEN, + "MediaRestrictions must have exactly one of event_id or profile_user_id set.", + errcode=Codes.FORBIDDEN, + ) + if bool(attachments.event_id) == bool(attachments.profile_user_id): + raise SynapseError( + HTTPStatus.FORBIDDEN, + "MediaRestrictions must have exactly one of event_id or profile_user_id set.", + errcode=Codes.FORBIDDEN, + ) + + if not attachments and is_federation: + raise SynapseError( + HTTPStatus.NOT_FOUND, + "Not found '%s'" % (request.path.decode(),), + errcode=Codes.NOT_FOUND, + ) + + if not attachments and not is_federation: + if ( + isinstance(request.requester, Requester) + and request.requester.user.to_string() != media_info.user_id + ): + raise SynapseError( + HTTPStatus.NOT_FOUND, + "Not found '%s'" % (request.path.decode(),), + errcode=Codes.NOT_FOUND, + ) + else: + return None + return attachments diff --git a/synapse/media/thumbnailer.py b/synapse/media/thumbnailer.py index 5d9afda322..a8bfee77a8 100644 --- a/synapse/media/thumbnailer.py +++ b/synapse/media/thumbnailer.py @@ -270,6 +270,7 @@ def __init__( self.media_storage = media_storage self.store = hs.get_datastores().main self.dynamic_thumbnails = hs.config.media.dynamic_thumbnails + self.msc3911_enabled = hs.config.experimental.msc3911_enabled async def respond_local_thumbnail( self, @@ -289,6 +290,13 @@ async def respond_local_thumbnail( if not media_info: return + restrictions = None + if self.msc3911_enabled: + restrictions = await 
self.media_repo.validate_media_restriction( + request, media_info, None, for_federation + ) + restrictions_json = restrictions.to_dict() if restrictions else {} + # if the media the thumbnail is generated from is authenticated, don't serve the # thumbnail over an unauthenticated endpoint if self.hs.config.media.enable_authenticated_media and not allow_authenticated: @@ -314,6 +322,7 @@ async def respond_local_thumbnail( server_name=None, for_federation=for_federation, media_info=media_info, + json_response=restrictions_json, ) async def select_or_generate_local_thumbnail( @@ -346,6 +355,13 @@ async def select_or_generate_local_thumbnail( return thumbnail_infos = await self.store.get_local_media_thumbnails(media_id) + restrictions = None + if self.msc3911_enabled: + restrictions = await self.media_repo.validate_media_restriction( + request, None, media_id, for_federation + ) + + restrictions_json = restrictions.to_dict() if restrictions else {} for info in thumbnail_infos: t_w = info.width == desired_width t_h = info.height == desired_height @@ -370,8 +386,10 @@ async def select_or_generate_local_thumbnail( info.type, info.length, None, + json_response=restrictions_json, ) return + else: await respond_with_responder( request, responder, info.type, info.length @@ -402,6 +420,7 @@ async def select_or_generate_local_thumbnail( file_info.thumbnail.type, file_info.thumbnail.length, None, + json_response=restrictions_json, ) else: await respond_with_file(self.hs, request, desired_type, file_path) @@ -560,6 +579,7 @@ async def _select_and_respond_with_thumbnail( for_federation: bool, media_info: Optional[LocalMedia] = None, server_name: Optional[str] = None, + json_response: Optional[dict] = None, ) -> None: """ Respond to a request with an appropriate thumbnail from the previously generated thumbnails. 
@@ -620,6 +640,7 @@ async def _select_and_respond_with_thumbnail( file_info.thumbnail.type, file_info.thumbnail.length, None, + json_response=json_response, ) return else: @@ -679,6 +700,7 @@ async def _select_and_respond_with_thumbnail( file_info.thumbnail.type, file_info.thumbnail.length, None, + json_response=json_response, ) else: await respond_with_responder( diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py index 55dcbd6fe7..9dc4a7fda0 100644 --- a/synapse/storage/databases/main/media_repository.py +++ b/synapse/storage/databases/main/media_repository.py @@ -74,6 +74,13 @@ class MediaRestrictions: event_id: Optional[str] = None profile_user_id: Optional[UserID] = None + def to_dict(self) -> dict: + if self.event_id: + return {"restrictions": {"event_id": str(self.event_id)}} + if self.profile_user_id: + return {"restrictions": {"profile_user_id": str(self.profile_user_id)}} + return {} + @attr.s(slots=True, frozen=True, auto_attribs=True) class LocalMedia: diff --git a/tests/federation/test_federation_media.py b/tests/federation/test_federation_media.py index cd4905239f..a53b6bf496 100644 --- a/tests/federation/test_federation_media.py +++ b/tests/federation/test_federation_media.py @@ -18,6 +18,7 @@ # # import io +import json import os import shutil import tempfile @@ -30,9 +31,11 @@ FileStorageProviderBackend, StorageProviderWrapper, ) +from synapse.rest.client import login from synapse.server import HomeServer -from synapse.types import UserID -from synapse.util import Clock +from synapse.storage.database import LoggingTransaction +from synapse.types import JsonDict, UserID +from synapse.util import Clock, json_encoder from tests import unittest from tests.media.test_media_storage import small_png @@ -187,6 +190,161 @@ def test_federation_etag(self) -> None: self.assertNotIn("body", channel.result) +class FederationRestrictedMediaDownloadsTest(unittest.FederatingHomeserverTestCase): + 
servlets = [ + login.register_servlets, + ] + + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + super().prepare(reactor, clock, hs) + self.test_dir = tempfile.mkdtemp(prefix="synapse-tests-") + self.addCleanup(shutil.rmtree, self.test_dir) + self.primary_base_path = os.path.join(self.test_dir, "primary") + self.secondary_base_path = os.path.join(self.test_dir, "secondary") + hs.config.media.media_store_path = self.primary_base_path + self.store = hs.get_datastores().main + + storage_providers = [ + StorageProviderWrapper( + FileStorageProviderBackend(hs, self.secondary_base_path), + store_local=True, + store_remote=False, + store_synchronous=True, + ) + ] + + self.filepaths = MediaFilePaths(self.primary_base_path) + self.media_storage = MediaStorage( + hs, self.primary_base_path, self.filepaths, storage_providers + ) + self.media_repo = hs.get_media_repository() + + def default_config(self) -> JsonDict: + config = super().default_config() + config.setdefault("experimental_features", {}) + config["experimental_features"].update({"msc3911_enabled": True}) + return config + + def test_restricted_media_download_with_restrictions_field(self) -> None: + content = io.BytesIO(SMALL_PNG) + content_uri = self.get_success( + self.media_repo.create_or_update_content( + "image/png", + "test_png_upload", + content, + 67, + UserID.from_string("@user_id:something.org"), + restricted=True, + ) + ) + # Attach restrictions to the media + self.get_success( + self.media_repo.store.set_media_restricted_to_event_id( + self.hs.hostname, content_uri.media_id, "random-event-id" + ) + ) + # Send download request with federation endpoint + channel = self.make_signed_federation_request( + "GET", + f"/_matrix/federation/v1/media/download/{content_uri.media_id}", + ) + self.assertEqual(200, channel.code) + + content_type = channel.headers.getRawHeaders("content-type") + assert content_type is not None + assert "multipart/mixed" in content_type[0] + assert 
"boundary" in content_type[0] + + boundary = content_type[0].split("boundary=")[1] + body = channel.result.get("body") + assert body is not None + + # Assert a JSON part exists with field restrictions + stripped_bytes = body.split(b"\r\n" + b"--" + boundary.encode("utf-8")) + json_obj = None + for part in stripped_bytes: + if b"Content-Type: application/json" in part: + idx = part.find(b"\r\n\r\n") + assert idx != -1, "No JSON payload found after header" + json_bytes = part[idx + 4 :].strip() + json_obj = json.loads(json_bytes.decode("utf-8")) + break + + assert json_obj is not None, "No JSON part found" + assert json_obj.get("restrictions", {}).get("event_id") == "random-event-id" + + # Check the png file exists and matches what was uploaded + found_file = any(SMALL_PNG in field for field in stripped_bytes) + self.assertTrue(found_file) + + def test_restricted_media_download_without_restrictions_field_fails(self) -> None: + content = io.BytesIO(SMALL_PNG) + content_uri = self.get_success( + self.media_repo.create_or_update_content( + "image/png", + "test_png_upload", + content, + 67, + UserID.from_string("@user_id:whatever.org"), + restricted=True, + ) + ) + + # Send download request with federation endpoint + channel = self.make_signed_federation_request( + "GET", + f"/_matrix/federation/v1/media/download/{content_uri.media_id}", + ) + self.assertEqual(404, channel.code) + self.assertIn(b"Not found", channel.result.get("body", b"")) + + def test_restricted_media_download_with_invalid_restrictions_field_fails( + self, + ) -> None: + content = io.BytesIO(SMALL_PNG) + content_uri = self.get_success( + self.media_repo.create_or_update_content( + "image/png", + "test_png_upload", + content, + 67, + UserID.from_string("@user_id:whatever.org"), + restricted=True, + ) + ) + # Append invalid restrictions set for test + json_object = {"random_field": "random_value"} + + def insert_restriction(txn: LoggingTransaction) -> None: + self.store.db_pool.simple_insert_txn( + txn, 
+ table="media_attachments", + values={ + "server_name": self.hs.hostname, + "media_id": content_uri.media_id, + "restrictions_json": json_encoder.encode(json_object), + }, + ) + + self.get_success( + self.store.db_pool.runInteraction( + "test_restricted_media_download_with_invalid_restrictions_field_fails", + insert_restriction, + ) + ) + + # Send download request with federation endpoint + channel = self.make_signed_federation_request( + "GET", + f"/_matrix/federation/v1/media/download/{content_uri.media_id}", + ) + self.assertEqual(403, channel.code) + self.assertIn( + b"MediaRestrictions must have exactly one of", + channel.result.get("body", b""), + ) + + class FederationThumbnailTest(unittest.FederatingHomeserverTestCase): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: super().prepare(reactor, clock, hs) @@ -293,3 +451,91 @@ def test_thumbnail_download_cropped(self) -> None: small_png.expected_cropped in field for field in stripped_bytes ) self.assertTrue(found_file) + + +class FederationRestrictedThumbnailTest(unittest.FederatingHomeserverTestCase): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + super().prepare(reactor, clock, hs) + self.test_dir = tempfile.mkdtemp(prefix="synapse-tests-") + self.addCleanup(shutil.rmtree, self.test_dir) + self.primary_base_path = os.path.join(self.test_dir, "primary") + self.secondary_base_path = os.path.join(self.test_dir, "secondary") + + hs.config.media.media_store_path = self.primary_base_path + + storage_providers = [ + StorageProviderWrapper( + FileStorageProviderBackend(hs, self.secondary_base_path), + store_local=True, + store_remote=False, + store_synchronous=True, + ) + ] + + self.filepaths = MediaFilePaths(self.primary_base_path) + self.media_storage = MediaStorage( + hs, self.primary_base_path, self.filepaths, storage_providers + ) + self.media_repo = hs.get_media_repository() + + def default_config(self) -> JsonDict: + config = 
super().default_config() + config.setdefault("experimental_features", {}) + config["experimental_features"].update({"msc3911_enabled": True}) + return config + + def test_restricted_thumbnail_download_with_restrictions_field(self) -> None: + content = io.BytesIO(small_png.data) + content_uri = self.get_success( + self.media_repo.create_or_update_content( + "image/png", + "test_png_thumbnail", + content, + 67, + UserID.from_string("@user_id:whatever.org"), + restricted=True, + ) + ) + # Attach restrictions to the media + self.get_success( + self.media_repo.store.set_media_restricted_to_user_profile( + self.hs.hostname, content_uri.media_id, "@user_id:whatever.org" + ) + ) + + # Send download request with federation endpoint + channel = self.make_signed_federation_request( + "GET", + f"/_matrix/federation/v1/media/thumbnail/{content_uri.media_id}?width=32&height=32&method=scale", + ) + self.assertEqual(200, channel.code) + + content_type = channel.headers.getRawHeaders("content-type") + assert content_type is not None + assert "multipart/mixed" in content_type[0] + assert "boundary" in content_type[0] + + boundary = content_type[0].split("boundary=")[1] + body = channel.result.get("body") + assert body is not None + + # Assert a JSON part exists with field restrictions + stripped_bytes = body.split(b"\r\n" + b"--" + boundary.encode("utf-8")) + json_obj = None + for part in stripped_bytes: + if b"Content-Type: application/json" in part: + idx = part.find(b"\r\n\r\n") + assert idx != -1, "No JSON payload found after header" + json_bytes = part[idx + 4 :].strip() + json_obj = json.loads(json_bytes.decode("utf-8")) + break + + assert json_obj is not None, "No JSON part found" + assert ( + json_obj.get("restrictions", {}).get("profile_user_id") + == "@user_id:whatever.org" + ) + + # Check that the png file exists and matches the expected scaled bytes + found_file = any(small_png.expected_scaled in field for field in stripped_bytes) + self.assertTrue(found_file) From 
c281798be77f6c054006a8f62e796d8fdda23643 Mon Sep 17 00:00:00 2001 From: Soyoung Kim Date: Wed, 10 Sep 2025 13:55:58 +0200 Subject: [PATCH 18/35] chore: namespace the json response --- synapse/storage/databases/main/media_repository.py | 8 ++++++-- tests/federation/test_federation_media.py | 7 +++++-- 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py index 9dc4a7fda0..0a245dea0f 100644 --- a/synapse/storage/databases/main/media_repository.py +++ b/synapse/storage/databases/main/media_repository.py @@ -76,9 +76,13 @@ class MediaRestrictions: def to_dict(self) -> dict: if self.event_id: - return {"restrictions": {"event_id": str(self.event_id)}} + return {"org.matrix.msc3911.restrictions": {"event_id": str(self.event_id)}} if self.profile_user_id: - return {"restrictions": {"profile_user_id": str(self.profile_user_id)}} + return { + "org.matrix.msc3911.restrictions": { + "profile_user_id": str(self.profile_user_id) + } + } return {} diff --git a/tests/federation/test_federation_media.py b/tests/federation/test_federation_media.py index a53b6bf496..a3e48d89fc 100644 --- a/tests/federation/test_federation_media.py +++ b/tests/federation/test_federation_media.py @@ -271,7 +271,10 @@ def test_restricted_media_download_with_restrictions_field(self) -> None: break assert json_obj is not None, "No JSON part found" - assert json_obj.get("restrictions", {}).get("event_id") == "random-event-id" + assert ( + json_obj.get("org.matrix.msc3911.restrictions", {}).get("event_id") + == "random-event-id" + ) # Check the png file exists and matches what was uploaded found_file = any(SMALL_PNG in field for field in stripped_bytes) @@ -532,7 +535,7 @@ def test_restricted_thumbnail_download_with_restrictions_field(self) -> None: assert json_obj is not None, "No JSON part found" assert ( - json_obj.get("restrictions", {}).get("profile_user_id") + 
json_obj.get("org.matrix.msc3911.restrictions", {}).get("profile_user_id") == "@user_id:whatever.org" ) From a05b032060177bdba7d1d97b87a922c7658cf500 Mon Sep 17 00:00:00 2001 From: Jason Little Date: Thu, 4 Sep 2025 12:52:25 -0500 Subject: [PATCH 19/35] MSC3911 AP7: Permission checks for download and thumbnail endpoints --- synapse/api/errors.py | 7 + synapse/media/media_repository.py | 247 +++++- synapse/media/thumbnailer.py | 56 +- synapse/rest/client/media.py | 10 +- tests/rest/client/test_media.py | 915 +++++++++++++++++++++- tests/rest/client/test_media_download.py | 346 ++++++++ tests/rest/client/test_media_thumbnail.py | 368 +++++++++ tests/rest/client/utils.py | 37 +- 8 files changed, 1952 insertions(+), 34 deletions(-) create mode 100644 tests/rest/client/test_media_download.py create mode 100644 tests/rest/client/test_media_thumbnail.py diff --git a/synapse/api/errors.py b/synapse/api/errors.py index b832c2f6a1..df63550109 100644 --- a/synapse/api/errors.py +++ b/synapse/api/errors.py @@ -245,6 +245,13 @@ def __init__(self, msg: str): super().__init__(HTTPStatus.BAD_REQUEST, msg, Codes.BAD_JSON) +class UnauthorizedRequestAPICallError(SynapseError): + """Error raised when a request was not allowed due to authorization""" + + def __init__(self, msg: str): + super().__init__(HTTPStatus.FORBIDDEN, msg, Codes.UNAUTHORIZED) + + class InvalidProxyCredentialsError(SynapseError): """Error raised when the proxy credentials are invalid.""" diff --git a/synapse/media/media_repository.py b/synapse/media/media_repository.py index f3a616964f..d28319c00f 100644 --- a/synapse/media/media_repository.py +++ b/synapse/media/media_repository.py @@ -25,7 +25,7 @@ import shutil from http import HTTPStatus from io import BytesIO -from typing import IO, TYPE_CHECKING, Dict, List, Optional, Set, Tuple +from typing import IO, TYPE_CHECKING, Dict, List, Optional, Set, Tuple, Union import attr from matrix_common.types.mxc_uri import MXCUri @@ -33,6 +33,7 @@ import twisted.web.http 
from twisted.internet.defer import Deferred +from synapse.api.constants import EventTypes, HistoryVisibility, Membership from synapse.api.errors import ( Codes, FederationDeniedError, @@ -40,6 +41,7 @@ NotFoundError, RequestSendFailed, SynapseError, + UnauthorizedRequestAPICallError, cs_error, ) from synapse.api.ratelimiting import Ratelimiter @@ -74,9 +76,17 @@ RemoteMedia, ) from synapse.types import Requester, UserID +from synapse.types.state import StateFilter from synapse.util.async_helpers import Linearizer from synapse.util.retryutils import NotRetryingDestination from synapse.util.stringutils import random_string +from synapse.visibility import ( + _HISTORY_VIS_KEY, + MEMBERSHIP_PRIORITY, + VISIBILITY_PRIORITY, + filter_events_for_client, + get_effective_room_visibility_from_state, +) if TYPE_CHECKING: from synapse.server import HomeServer @@ -187,7 +197,8 @@ def __init__(self, hs: "HomeServer"): self.media_upload_limits.sort( key=lambda limit: limit.time_period_ms, reverse=True ) - self.msc3911_enabled = hs.config.experimental.msc3911_enabled + + self.enable_media_restriction = self.hs.config.experimental.msc3911_enabled def _start_update_recently_accessed(self) -> Deferred: return run_as_background_process( @@ -466,12 +477,197 @@ async def get_local_media_info( self.respond_not_yet_uploaded(request) return None + async def is_media_visible( + self, requesting_user: UserID, media_info_object: Union[LocalMedia, RemoteMedia] + ) -> None: + """ + Verify that media requested for download should be visible to the user making + the request + """ + + if not self.enable_media_restriction: + return + + if not media_info_object.restricted: + return + + if not media_info_object.attachments: + # When the media has not been attached yet, only the originating user can + # see it. 
But once attachments have been formed, standard other rules apply + if isinstance(media_info_object, LocalMedia) and ( + requesting_user.to_string() == str(media_info_object.user_id) + ): + return + + # It was restricted, but no attachments. Deny + raise UnauthorizedRequestAPICallError( + f"Media requested ('{media_info_object.media_id}') is restricted" + ) + + attached_event_id = media_info_object.attachments.event_id + attached_profile_user_id = media_info_object.attachments.profile_user_id + + if attached_event_id: + event_base = await self.store.get_event(attached_event_id) + storage_controllers = self.hs.get_storage_controllers() + if event_base.is_state(): + # The standard event visibility utility, filter_events_for_client(), + # does not seem to meet the needs of a good UX when restricting and + # allowing media. This is a very, very simple version to be used for + # state events. + + # First we will collect the current membership of the user for the room + # the relevant event came from. Then we will collect the membership and + # m.room.history_visibility event at the time of the relevant event. 
+ + # Since it is hard to find a relevant place in which to search back in + # time to find out if a given room ever had anything other than a leave + # event, this is the simplest without having to do tablescans + + # Need membership of NOW + ( + membership_now, + _, + ) = await self.store.get_local_current_membership_for_user_in_room( + requesting_user.to_string(), event_base.room_id + ) + + if not membership_now: + membership_now = Membership.LEAVE + + membership_state_key = (EventTypes.Member, requesting_user.to_string()) + types = (_HISTORY_VIS_KEY, membership_state_key) + # and history visibility and membership of THEN + event_id_to_state = ( + await storage_controllers.state.get_state_for_events( + [attached_event_id], + state_filter=StateFilter.from_types(types), + ) + ) + + state_map = event_id_to_state.get(attached_event_id) + # Do we need to guard against not having state of a room? + assert state_map is not None + + visibility = get_effective_room_visibility_from_state(state_map) + + memb_then_evt = state_map.get(membership_state_key) + membership_then = Membership.LEAVE + if memb_then_evt: + membership_then = memb_then_evt.content.get( + "membership", Membership.LEAVE + ) + + # Have a few numbers ready for comparison below. These resolve to int + # The index of the visibility present from the event + visibility_priority = VISIBILITY_PRIORITY.index(visibility) + membership_priority_now = MEMBERSHIP_PRIORITY.index(membership_now) + membership_priority_then = MEMBERSHIP_PRIORITY.index(membership_then) + + # These are essentially constants, in that they should not change + world_readable_index = VISIBILITY_PRIORITY.index( + HistoryVisibility.WORLD_READABLE + ) + shared_visibility_index = VISIBILITY_PRIORITY.index( + HistoryVisibility.SHARED + ) + mem_leave_index = MEMBERSHIP_PRIORITY.index(Membership.LEAVE) + + # I disagree with this. 
'Shared' by spec implies that some sort of + # positive membership event took place, but the stock + # filter_events_for_client() seems to treat SHARED like WORLD_READABLE, + # so at least this matches + if visibility_priority in [ + world_readable_index, + shared_visibility_index, + ]: + # world readable should always be seen + return + + # If the room is invite visible, and the user is invited, move on + if visibility_priority == VISIBILITY_PRIORITY.index( + HistoryVisibility.INVITED + ) and membership_priority_now == MEMBERSHIP_PRIORITY.index( + Membership.INVITE + ): + return + + # The visibility of the room is shared or greater, so requires at + # the minimum a 'knock' level. Make sure the membership of the user + # is better than leave + if ( + visibility_priority >= shared_visibility_index + and membership_priority_now < mem_leave_index + ): + return + + # Cover the case that a user has left a room but still should see any + # media they were allowed to see prior + # The visibility of the room is shared or greater, so requires at + # the minimum a 'knock' level. Make sure the membership of the user + # is better than leave + if ( + visibility_priority >= shared_visibility_index + and membership_priority_then < mem_leave_index + ): + return + + else: + filtered_events = await filter_events_for_client( + storage_controllers, + requesting_user.to_string(), + [event_base], + ) + if len(filtered_events) > 0: + return + + elif attached_profile_user_id: + # Can this user see that profile? + + # The error returns here may not be suitable, use the work around below + # If shared room restricted profile lookups, it will be restricted + # to users that share rooms + # await self.profile_handler.check_profile_query_allowed( + # restrictions.profile_user_id, requester.user + # ) + # return + + if self.hs.config.server.limit_profile_requests_to_users_who_share_rooms: + # First take care of the case where the requesting user IS the creating + # user. 
The other function below does not handle this. + if requesting_user.to_string() == attached_profile_user_id.to_string(): + return + + # This call returns a set() that contains which of the "other_user_ids" + # share a room. Since we give it only one, if bool(set()) is True, then they + # share some room or had at least one invite between them. + if not await self.store.do_users_share_a_room_joined_or_invited( + requesting_user.to_string(), + [attached_profile_user_id.to_string()], + ): + raise UnauthorizedRequestAPICallError( + f"Media requested ('{media_info_object.media_id}') is restricted" + ) + + # check these settings: + # * allow_profile_lookup_over_federation + + # If 'limit_profile_requests_to_users_who_share_rooms' is not enabled, all + # bets are kinda off + return + + # It was a third unknown restriction, or otherwise did not pass inspection + raise UnauthorizedRequestAPICallError( + f"Media requested ('{media_info_object.media_id}') is restricted" + ) + async def get_local_media( self, request: SynapseRequest, media_id: str, name: Optional[str], max_timeout_ms: int, + requester: Optional[Requester] = None, allow_authenticated: bool = True, federation: bool = False, ) -> None: @@ -485,6 +681,8 @@ async def get_local_media( the filename in the Content-Disposition header of the response. max_timeout_ms: the maximum number of milliseconds to wait for the media to be uploaded. + requester: The user making the request, to verify restricted media. 
Only + used for local users, not over federation allow_authenticated: whether media marked as authenticated may be served to this request federation: whether the local media being fetched is for a federation request @@ -500,7 +698,13 @@ async def get_local_media( raise NotFoundError() restrictions = None - if self.msc3911_enabled: + # if MSC3911 is enabled, check visibility of the media for the user and retrieve + # any restrictions + if self.enable_media_restriction: + if requester is not None: + # Only check media visibility if this is for a local request. This will + # raise directly back to the client if not visible + await self.is_media_visible(requester.user, media_info) restrictions = await self.validate_media_restriction( request, media_info, None, federation ) @@ -547,6 +751,7 @@ async def get_remote_media( max_timeout_ms: int, ip_address: str, use_federation_endpoint: bool, + requester: Optional[Requester] = None, allow_authenticated: bool = True, ) -> None: """Respond to requests for remote media. @@ -562,6 +767,8 @@ async def get_remote_media( ip_address: the IP address of the requester use_federation_endpoint: whether to request the remote media over the new federation `/download` endpoint + requester: The user making the request, to verify restricted media. Only + used for local users, not over federation allow_authenticated: whether media marked as authenticated may be served to this request @@ -596,6 +803,7 @@ async def get_remote_media( ip_address, use_federation_endpoint, allow_authenticated, + requester, ) # Check if the media is cached on the client, if so return 304. We need @@ -630,6 +838,7 @@ async def get_remote_media_info( ip_address: str, use_federation: bool, allow_authenticated: bool, + requester: Optional[Requester] = None, ) -> RemoteMedia: """Gets the media info associated with the remote file, downloading if necessary. 
@@ -644,6 +853,8 @@ async def get_remote_media_info( over the federation `/download` endpoint allow_authenticated: whether media marked as authenticated may be served to this request + requester: The user making the request, to verify restricted media. Only + used for local users, not over federation Returns: The media info of the file @@ -666,6 +877,7 @@ async def get_remote_media_info( ip_address, use_federation, allow_authenticated, + requester, ) # Ensure we actually use the responder so that it releases resources @@ -684,6 +896,7 @@ async def _get_remote_media_impl( ip_address: str, use_federation_endpoint: bool, allow_authenticated: bool, + requester: Optional[Requester] = None, ) -> Tuple[Optional[Responder], RemoteMedia]: """Looks for media in local cache, if not there then attempt to download from remote server. @@ -699,6 +912,9 @@ async def _get_remote_media_impl( ip_address: the IP address of the requester use_federation_endpoint: whether to request the remote media over the new federation /download endpoint + allow_authenticated: + requester: The user making the request, to verify restricted media. Only + used for local users, not over federation Returns: A tuple of responder and the media info of the file. @@ -710,12 +926,19 @@ async def _get_remote_media_impl( if not media_info or media_info.authenticated: raise NotFoundError() - # file_id is the ID we use to track the file locally. If we've already - # seen the file then reuse the existing ID, otherwise generate a new - # one. - # If we have an entry in the DB, try and look for it if media_info: + # if MSC3911 is enabled, check visibility of the media for the user. This + # check exists twice in this function, once up here for when it already + # exists in the local database and again further down for after it was + # retrieved from the remote. 
+ if self.enable_media_restriction and requester is not None: + # This will raise directly back to the client if not visible + await self.is_media_visible(requester.user, media_info) + + # file_id is the ID we use to track the file locally. If we've already + # seen the file then reuse the existing ID, otherwise generate a new + # one. file_id = media_info.filesystem_id file_info = FileInfo(server_name, file_id) @@ -761,6 +984,16 @@ async def _get_remote_media_impl( if not media_info: raise e + # if MSC3911 is enabled, check visibility of the media for the user. + # Restricted media requires authentication to be enabled + if ( + self.hs.config.media.enable_authenticated_media + and self.enable_media_restriction + and requester is not None + ): + # This will raise directly back to the client if not visible + await self.is_media_visible(requester.user, media_info) + file_id = media_info.filesystem_id if not media_info.media_type: media_info = attr.evolve(media_info, media_type="application/octet-stream") diff --git a/synapse/media/thumbnailer.py b/synapse/media/thumbnailer.py index a8bfee77a8..0b2b37d808 100644 --- a/synapse/media/thumbnailer.py +++ b/synapse/media/thumbnailer.py @@ -42,6 +42,7 @@ ) from synapse.media.media_storage import FileResponder, MediaStorage from synapse.storage.databases.main.media_repository import LocalMedia +from synapse.types import Requester if TYPE_CHECKING: from synapse.media.media_repository import MediaRepository @@ -270,7 +271,7 @@ def __init__( self.media_storage = media_storage self.store = hs.get_datastores().main self.dynamic_thumbnails = hs.config.media.dynamic_thumbnails - self.msc3911_enabled = hs.config.experimental.msc3911_enabled + self.enable_media_restriction = self.hs.config.experimental.msc3911_enabled async def respond_local_thumbnail( self, @@ -282,6 +283,7 @@ async def respond_local_thumbnail( m_type: str, max_timeout_ms: int, for_federation: bool, + requester: Optional[Requester] = None, allow_authenticated: bool 
= True, ) -> None: media_info = await self.media_repo.get_local_media_info( @@ -290,19 +292,25 @@ async def respond_local_thumbnail( if not media_info: return - restrictions = None - if self.msc3911_enabled: - restrictions = await self.media_repo.validate_media_restriction( - request, media_info, None, for_federation - ) - restrictions_json = restrictions.to_dict() if restrictions else {} - # if the media the thumbnail is generated from is authenticated, don't serve the # thumbnail over an unauthenticated endpoint if self.hs.config.media.enable_authenticated_media and not allow_authenticated: if media_info.authenticated: raise NotFoundError() + # if MSC3911 is enabled, check visibility of the media for the user and retrieve + # any restrictions + restrictions = None + if self.enable_media_restriction: + if requester is not None: + # Only check media visibility if this is for a local request. This will + # raise directly back to the client if not visible + await self.media_repo.is_media_visible(requester.user, media_info) + restrictions = await self.media_repo.validate_media_restriction( + request, media_info, None, for_federation + ) + restrictions_json = restrictions.to_dict() if restrictions else {} + # Once we've checked auth we can return early if the media is cached on # the client if check_for_cached_entry_and_respond(request): @@ -335,6 +343,7 @@ async def select_or_generate_local_thumbnail( desired_type: str, max_timeout_ms: int, for_federation: bool, + requester: Optional[Requester] = None, allow_authenticated: bool = True, ) -> None: media_info = await self.media_repo.get_local_media_info( @@ -349,17 +358,24 @@ async def select_or_generate_local_thumbnail( if media_info.authenticated: raise NotFoundError() + # if MSC3911 is enabled, check visibility of the media for the user and retrieve + # any restrictions + restrictions = None + if self.enable_media_restriction: + if requester is not None: + # Only check media visibility if this is for a local request. 
This will + # raise directly back to the client if not visible + await self.media_repo.is_media_visible(requester.user, media_info) + restrictions = await self.media_repo.validate_media_restriction( + request, None, media_id, for_federation + ) + # Once we've checked auth we can return early if the media is cached on # the client if check_for_cached_entry_and_respond(request): return thumbnail_infos = await self.store.get_local_media_thumbnails(media_id) - restrictions = None - if self.msc3911_enabled: - restrictions = await self.media_repo.validate_media_restriction( - request, None, media_id, for_federation - ) restrictions_json = restrictions.to_dict() if restrictions else {} for info in thumbnail_infos: @@ -440,6 +456,7 @@ async def select_or_generate_remote_thumbnail( max_timeout_ms: int, ip_address: str, use_federation: bool, + requester: Optional[Requester] = None, allow_authenticated: bool = True, ) -> None: media_info = await self.media_repo.get_remote_media_info( @@ -449,6 +466,7 @@ async def select_or_generate_remote_thumbnail( ip_address, use_federation, allow_authenticated, + requester, ) if not media_info: respond_404(request) @@ -461,6 +479,11 @@ async def select_or_generate_remote_thumbnail( respond_404(request) return + # if MSC3911 is enabled, check visibility of the media for the user + if self.enable_media_restriction and requester is not None: + # This will raise directly back to the client if not visible + await self.media_repo.is_media_visible(requester.user, media_info) + # Check if the media is cached on the client, if so return 304. 
if check_for_cached_entry_and_respond(request): return @@ -522,6 +545,7 @@ async def respond_remote_thumbnail( max_timeout_ms: int, ip_address: str, use_federation: bool, + requester: Optional[Requester] = None, allow_authenticated: bool = True, ) -> None: # TODO: Don't download the whole remote file @@ -534,6 +558,7 @@ async def respond_remote_thumbnail( ip_address, use_federation, allow_authenticated, + requester, ) if not media_info: return @@ -544,6 +569,11 @@ async def respond_remote_thumbnail( if media_info.authenticated: raise NotFoundError() + # if MSC3911 is enabled, check visibility of the media for the user + if self.enable_media_restriction and requester is not None: + # This will raise directly back to the client if not visible + await self.media_repo.is_media_visible(requester.user, media_info) + # Check if the media is cached on the client, if so return 304. if check_for_cached_entry_and_respond(request): return diff --git a/synapse/rest/client/media.py b/synapse/rest/client/media.py index c620d9f2a5..80b9b8c65b 100644 --- a/synapse/rest/client/media.py +++ b/synapse/rest/client/media.py @@ -159,7 +159,7 @@ async def on_GET( ) -> None: # Validate the server name, raising if invalid parse_and_validate_server_name(server_name) - await self.auth.get_user_by_req(request, allow_guest=True) + requester = await self.auth.get_user_by_req(request, allow_guest=True) set_cors_headers(request) set_corp_headers(request) @@ -184,6 +184,7 @@ async def on_GET( m_type, max_timeout_ms, False, + requester, ) else: await self.thumbnailer.respond_local_thumbnail( @@ -195,6 +196,7 @@ async def on_GET( m_type, max_timeout_ms, False, + requester, ) self.media_repo.mark_recently_accessed(None, media_id) else: @@ -223,6 +225,7 @@ async def on_GET( max_timeout_ms, ip_address, True, + requester, ) self.media_repo.mark_recently_accessed(server_name, media_id) @@ -250,7 +253,7 @@ async def on_GET( # Validate the server name, raising if invalid 
parse_and_validate_server_name(server_name) - await self.auth.get_user_by_req(request, allow_guest=True) + requester = await self.auth.get_user_by_req(request, allow_guest=True) set_cors_headers(request) set_corp_headers(request) @@ -274,7 +277,7 @@ async def on_GET( if self._is_mine_server_name(server_name): await self.media_repo.get_local_media( - request, media_id, file_name, max_timeout_ms + request, media_id, file_name, max_timeout_ms, requester ) else: ip_address = request.getClientAddress().host @@ -286,6 +289,7 @@ async def on_GET( max_timeout_ms, ip_address, True, + requester, ) diff --git a/tests/rest/client/test_media.py b/tests/rest/client/test_media.py index bffbd982ef..1b1cf7641a 100644 --- a/tests/rest/client/test_media.py +++ b/tests/rest/client/test_media.py @@ -24,11 +24,15 @@ import os import re import shutil +import time +from contextlib import nullcontext +from http import HTTPStatus from typing import Any, BinaryIO, ClassVar, Dict, List, Optional, Sequence, Tuple, Type from unittest.mock import MagicMock, Mock, patch from urllib import parse from urllib.parse import quote, urlencode +from matrix_common.types.mxc_uri import MXCUri from parameterized import parameterized, parameterized_class from PIL import Image as Image @@ -44,7 +48,12 @@ from twisted.web.iweb import UNKNOWN_LENGTH, IResponse from twisted.web.resource import Resource -from synapse.api.errors import Codes, HttpResponseException +from synapse.api.constants import EventTypes, HistoryVisibility, Membership +from synapse.api.errors import ( + Codes, + HttpResponseException, + UnauthorizedRequestAPICallError, +) from synapse.api.ratelimiting import Ratelimiter from synapse.config.oembed import OEmbedEndpointConfig from synapse.http.client import MultipartResponse @@ -54,10 +63,11 @@ from synapse.media.thumbnailer import ThumbnailProvider from synapse.media.url_previewer import IMAGE_CACHE_EXPIRY_MS from synapse.rest import admin -from synapse.rest.client import login, media +from 
synapse.rest.client import login, media, room from synapse.server import HomeServer -from synapse.types import JsonDict, UserID -from synapse.util import Clock +from synapse.storage.databases.main.media_repository import LocalMedia +from synapse.types import JsonDict, UserID, create_requester +from synapse.util import Clock, json_encoder from synapse.util.stringutils import parse_and_validate_mxc_uri from tests import unittest @@ -3115,7 +3125,7 @@ def test_upload_restricted_resource(self) -> None: f"/_matrix/client/v1/media/download/{self.hs.hostname}/{media_id}", access_token=self.creator_tok, ) - assert channel.code == 200 + assert channel.code == 200, channel.json_body # The other user cannot download the restricted resource in pending state. channel = self.make_request( @@ -3123,7 +3133,7 @@ def test_upload_restricted_resource(self) -> None: f"/_matrix/client/v1/media/download/{self.hs.hostname}/{media_id}", access_token=self.other_user_tok, ) - assert channel.code == 404 + assert channel.code == 403, channel.json_body def test_async_upload_restricted_resource(self) -> None: """ @@ -3167,7 +3177,7 @@ def test_async_upload_restricted_resource(self) -> None: f"/_matrix/client/v1/media/download/{self.hs.hostname}/{media_id}", access_token=self.creator_tok, ) - assert channel.code == 200 + assert channel.code == 200, channel.json_body # The other user cannot download the restricted resource. channel = self.make_request( @@ -3175,7 +3185,7 @@ def test_async_upload_restricted_resource(self) -> None: f"/_matrix/client/v1/media/download/{self.hs.hostname}/{media_id}", access_token=self.other_user_tok, ) - assert channel.code == 404 + assert channel.code == 403, channel.json_body class CopyRestrictedResource(unittest.HomeserverTestCase): @@ -3319,3 +3329,892 @@ def test_copy_remote_restricted_resource(self) -> None: # Check if copied media is unattached to any event or user profile yet. 
assert copied_media.attachments is None + + +class RestrictedMediaVisibilityTestCase(unittest.HomeserverTestCase): + servlets = [ + admin.register_servlets, + login.register_servlets, + media.register_servlets, + room.register_servlets, + room.register_deprecated_servlets, + ] + + def default_config(self) -> JsonDict: + config = super().default_config() + config.setdefault("experimental_features", {}) + config["experimental_features"].update({"msc3911_enabled": True}) + return config + + def prepare( + self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer + ) -> None: + self.store = homeserver.get_datastores().main + self.server_name = self.hs.config.server.server_name + self.media_repo = self.hs.get_media_repository() + self.profile_handler = self.hs.get_profile_handler() + + self.alice_user_id = self.register_user("alice", "password") + self.alice_tok = self.login("alice", "password") + + def create_resource_dict(self) -> Dict[str, Resource]: + resources = super().create_resource_dict() + # The old endpoints are not loaded with the register_servlets above + resources["/_matrix/media"] = self.hs.get_media_repository_resource() + return resources + + def create_restricted_media(self, user: Optional[str] = None) -> MXCUri: + mxc_uri = self.get_success( + self.media_repo.create_or_update_content( + "image/png", + "test_png_upload", + io.BytesIO(SMALL_PNG), + 67, + UserID.from_string(user or self.alice_user_id), + restricted=True, + ) + ) + return mxc_uri + + def retrieve_media_from_store(self, mxc_uri: MXCUri) -> LocalMedia: + local_media_object = self.get_success( + self.store.get_local_media(mxc_uri.media_id) + ) + assert local_media_object is not None + return local_media_object + + def create_test_room( + self, + visibility: str, + initial_room_avatar_mxc: Optional[MXCUri] = None, + invite_list: Optional[List[str]] = None, + ) -> str: + initial_state_list = [ + { + "type": EventTypes.RoomHistoryVisibility, + "state_key": "", + "content": 
{"history_visibility": visibility}, + } + ] + + if initial_room_avatar_mxc: + # Simulate the same situation as `create_room()`, placing the room avatar + # after the history visibility event. + initial_state_list.append( + { + "type": EventTypes.RoomAvatar, + "state_key": "", + "content": { + "info": {"h": 1, "mimetype": "image/png", "size": 67, "w": 1}, + "url": str(initial_room_avatar_mxc), + }, + } + ) + extra_content: Dict[str, Any] = { + "initial_state": initial_state_list, + } + + if invite_list: + extra_content.update({"invite": invite_list}) + + room_id = self.helper.create_room_as( + self.alice_user_id, + is_public=True, + tok=self.alice_tok, + extra_content=extra_content, + ) + assert room_id is not None + return room_id + + def send_event_with_attached_media( + self, room_id: str, mxc_uri: MXCUri, tok: Optional[str] = None + ) -> FakeChannel: + txn_id = "m%s" % (str(time.time())) + + channel1 = self.make_request( + "PUT", + f"/rooms/{room_id}/send/m.room.message/{txn_id}?org.matrix.msc3911.attach_media={str(mxc_uri)}", + content={"msgtype": "m.text", "body": "Hi, this is a message"}, + access_token=tok or self.alice_tok, + ) + return channel1 + + def insert_message_with_attached_media(self, room_id: str) -> LocalMedia: + mxc_uri = self.create_restricted_media() + + channel = self.send_event_with_attached_media(room_id, mxc_uri) + self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body) + assert "event_id" in channel.json_body + event_id = channel.json_body["event_id"] + + # check restrictions are applied + restrictions = self.get_success( + self.store.get_media_restrictions(mxc_uri.server_name, mxc_uri.media_id) + ) + assert restrictions is not None, str(restrictions) + assert restrictions.event_id == event_id + assert restrictions.profile_user_id is None + + # Retrieve the media object for inspection in the test + local_media_object = self.retrieve_media_from_store(mxc_uri) + + return local_media_object + + def assert_expected_result( + self, 
target_user: UserID, media_object: LocalMedia, expected_bool: bool + ) -> None: + # If the expectation is True, it is expected the function is a success + # If the expectation is False, it is expected to raise the exception + # + # It may be easier to think of True and False here as Visible and Not Visible in + # the context of this TestCase. + if expected_bool: + # As a note about nullcontext manager: Support for async context behavior + # was added in python 3.10. However, we don't need that even when testing an + # async function, as it is wrapped into a sync function('get_success()' and + # 'get_success_or_raise()') to pump the reactor. + maybe_assert_exception = nullcontext() + else: + maybe_assert_exception = self.assertRaises(UnauthorizedRequestAPICallError) + with maybe_assert_exception: + self.get_success_or_raise( + self.media_repo.is_media_visible(target_user, media_object) + ) + + @parameterized.expand( + [ + (HistoryVisibility.WORLD_READABLE, True), + (HistoryVisibility.SHARED, True), + (HistoryVisibility.JOINED, False), + (HistoryVisibility.INVITED, False), + ] + ) + def test_message_media_visibility_to_users_not_sharing_room( + self, visibility: str, expected_result_bool: bool + ) -> None: + """ + Test that a message with restricted media is not visible if appropriate to + someone not in that room + """ + + second_user = self.register_user("second_user_message_test", "password") + second_user_id = UserID.from_string(second_user) + self.login("second_user_message_test", "password") + + room_id = self.create_test_room( + visibility, + ) + + local_media_object = self.insert_message_with_attached_media(room_id) + + self.assert_expected_result( + second_user_id, local_media_object, expected_result_bool + ) + + @parameterized.expand( + [ + (HistoryVisibility.WORLD_READABLE, True, True), + (HistoryVisibility.SHARED, True, True), + (HistoryVisibility.INVITED, False, True), + (HistoryVisibility.JOINED, False, False), + ] + ) + def 
test_message_media_visibility_to_users_invited_to_room( + self, + visibility: str, + expected_before_result_bool: bool, + expected_after_result_bool: bool, + ) -> None: + """ + Test that a message with restricted media is visible if appropriate to someone + invited to a room + """ + + invited_user = self.register_user("invited_user_message_test", "password") + invited_user_id = UserID.from_string(invited_user) + self.login("invited_user_message_test", "password") + + room_id = self.create_test_room( + visibility, + ) + + # We will need two pieces of media: one to check for before the invite and one + # for after. + + # Send the first attached media message + first_media_object = self.insert_message_with_attached_media(room_id) + + self.assert_expected_result( + invited_user_id, first_media_object, expected_before_result_bool + ) + + # Good, now do the invite and check the second expectation + self.helper.invite( + room_id, self.alice_user_id, invited_user, tok=self.alice_tok + ) + + # Send the second attached media message + second_media_object = self.insert_message_with_attached_media(room_id) + + self.assert_expected_result( + invited_user_id, second_media_object, expected_after_result_bool + ) + + @parameterized.expand( + [ + (HistoryVisibility.WORLD_READABLE, True, True), + (HistoryVisibility.SHARED, True, True), + (HistoryVisibility.INVITED, False, True), + (HistoryVisibility.JOINED, False, True), + ] + ) + def test_message_media_visibility_to_users_joined_to_room( + self, + visibility: str, + expected_before_result_bool: bool, + expected_after_result_bool: bool, + ) -> None: + """ + Test that a message with restricted media is visible if appropriate to someone + joined to a room + """ + + joining_user = self.register_user("joining_user_message_test", "password") + joining_user_id = UserID.from_string(joining_user) + joining_user_tok = self.login("joining_user_message_test", "password") + + room_id = self.create_test_room( + visibility, + ) + + # We will need 
two pieces of media: one to check for before the join and one + # for after. + + # Send the first attached media message + first_media_object = self.insert_message_with_attached_media(room_id) + + self.assert_expected_result( + joining_user_id, first_media_object, expected_before_result_bool + ) + + # Good, now do the join and check the second expectation + self.helper.join(room_id, joining_user, tok=joining_user_tok) + + # Send the second attached media message + second_media_object = self.insert_message_with_attached_media(room_id) + + self.assert_expected_result( + joining_user_id, second_media_object, expected_after_result_bool + ) + + @parameterized.expand( + [ + (HistoryVisibility.WORLD_READABLE, True, True), + (HistoryVisibility.SHARED, True, True), + (HistoryVisibility.INVITED, True, False), + (HistoryVisibility.JOINED, True, False), + ] + ) + def test_message_media_visibility_to_users_that_left_a_room( + self, + visibility: str, + expected_before_result_bool: bool, + expected_after_result_bool: bool, + ) -> None: + """ + Test that a message with restricted media is visible if appropriate to someone + that left a room + """ + # make another user for this test only. This user will be invited to the room, + # but will not actually join. + leaving_user = self.register_user("leaving_user_message_test", "password") + leaving_user_id = UserID.from_string(leaving_user) + leaving_user_tok = self.login("leaving_user_message_test", "password") + + room_id = self.create_test_room( + visibility, + ) + # Join the user, or else they can not leave + self.helper.join(room_id, leaving_user, tok=leaving_user_tok) + + # We will need two pieces of media: one to check for before the leave and one + # for after. 
@parameterized.expand(
    [
        HistoryVisibility.WORLD_READABLE,
        HistoryVisibility.SHARED,
        HistoryVisibility.INVITED,
        HistoryVisibility.JOINED,
    ]
)
def test_message_media_visibility_for_unknown_restriction(
    self,
    visibility: str,
) -> None:
    """
    Test that a message with restricted media is not visible if an unknown restriction exists

    The media is attached by writing a `media_attachments` row directly, with a
    restrictions object that names neither an event nor a profile. Both the joined
    second user and alice are then expected to be refused.
    """
    # Borrow the test setup for joining a room
    joining_user = self.register_user("joining_user_message_test", "password")
    joining_user_id = UserID.from_string(joining_user)
    joining_user_tok = self.login("joining_user_message_test", "password")
    # NOTE(review): alice is the user the rest of this class sends events as;
    # presumably she is also the uploader used by create_restricted_media().
    alice_user_id = UserID.from_string(self.alice_user_id)

    room_id = self.create_test_room(
        visibility,
    )
    # We'll go ahead and join the second user to the room, so the visibility of a
    # non-originating user can check the media's visibility
    self.helper.join(room_id, joining_user, tok=joining_user_tok)

    # First, create our piece of media and label it as restricted
    mxc_uri = self.create_restricted_media()

    # Then add the database row that would normally attach this media to something,
    # but it is an unknown something.
    self.get_success(
        self.store.db_pool.simple_insert(
            "media_attachments",
            {
                "server_name": mxc_uri.server_name,
                "media_id": mxc_uri.media_id,
                # "restrictions_json" is a JSONB column, so it expects a string
                "restrictions_json": json_encoder.encode({"restrictions": {}}),
            },
        )
    )

    # Retrieve the media info from the data store
    local_media_object = self.retrieve_media_from_store(mxc_uri)

    assert local_media_object.restricted is True
    # If attachments was None, we would know it had not been attached yet
    assert local_media_object.attachments is not None

    # Neither user should be able to see the media. The original asserted the
    # joining user twice and never checked alice.
    self.assert_expected_result(joining_user_id, local_media_object, False)
    self.assert_expected_result(alice_user_id, local_media_object, False)
@parameterized.expand(
    [
        (HistoryVisibility.WORLD_READABLE, True, True),
        (HistoryVisibility.SHARED, True, True),
        # Because invites are expected to see member's avatars
        (HistoryVisibility.INVITED, False, True),
        (HistoryVisibility.JOINED, False, True),
    ]
)
def test_membership_avatar_media_visibility_to_users_invited_to_room(
    self,
    visibility: str,
    expected_first_result_bool: bool,
    expected_second_result_bool: bool,
) -> None:
    """
    Check alice's membership-based avatar against a user that is only invited.

    The invited user never joins. Two scenarios are exercised:

    1. The avatar is set first and checked against `expected_first_result_bool`;
       the invite is sent afterwards and the same media is checked against
       `expected_second_result_bool`.
    2. The invite is part of room creation and the avatar is set afterwards.
       This is checked against `expected_second_result_bool` as well.
    """
    localpart = "invited_user_membership_test"
    invitee = self.register_user(localpart, "password")
    invitee_id = UserID.from_string(invitee)
    self.login(localpart, "password")

    # Scenario 1: avatar first, invite second
    first_room_id = self.create_test_room(visibility)
    avatar_before_invite = self.insert_membership_with_attached_media(
        first_room_id
    )
    self.assert_expected_result(
        invitee_id, avatar_before_invite, expected_first_result_bool
    )

    self.helper.invite(
        first_room_id, self.alice_user_id, invitee, tok=self.alice_tok
    )
    self.assert_expected_result(
        invitee_id, avatar_before_invite, expected_second_result_bool
    )

    # Scenario 2: invite sent as part of room creation, avatar set afterwards
    second_room_id = self.create_test_room(visibility, invite_list=[invitee])
    avatar_after_invite = self.insert_membership_with_attached_media(
        second_room_id
    )
    self.assert_expected_result(
        invitee_id, avatar_after_invite, expected_second_result_bool
    )
@parameterized.expand(
    [
        (HistoryVisibility.WORLD_READABLE, True, True),
        (HistoryVisibility.SHARED, True, True),
        (HistoryVisibility.INVITED, True, False),
        (HistoryVisibility.JOINED, True, False),
    ]
)
def test_membership_avatar_media_visibility_to_users_that_left_room(
    self,
    visibility: str,
    expected_first_result_bool: bool,
    expected_second_result_bool: bool,
) -> None:
    """
    Check membership-based avatars against a user that joined and then left.

    An avatar set while the user is joined must be visible to them at that time;
    after the leave it is re-checked against `expected_first_result_bool`. A
    second avatar set after the leave is checked against
    `expected_second_result_bool`.
    """
    localpart = "leaving_user_membership_test"
    leaver = self.register_user(localpart, "password")
    leaver_id = UserID.from_string(leaver)
    leaver_tok = self.login(localpart, "password")

    room_id = self.create_test_room(visibility)
    # A user can only leave a room they have joined
    self.helper.join(room_id, leaver, tok=leaver_tok)

    # Two pieces of media are needed: one from before the leave, one from after
    avatar_before_leave = self.insert_membership_with_attached_media(room_id)

    # Still a member at this point, so this is expected to succeed regardless
    self.assert_expected_result(leaver_id, avatar_before_leave, True)

    self.helper.leave(room_id, leaver, tok=leaver_tok)

    # Same media, now looked at from outside of the room
    self.assert_expected_result(
        leaver_id, avatar_before_leave, expected_first_result_bool
    )

    # And a fresh avatar that only exists after the leave
    avatar_after_leave = self.insert_membership_with_attached_media(room_id)
    self.assert_expected_result(
        leaver_id, avatar_after_leave, expected_second_result_bool
    )
@parameterized.expand(
    [
        (HistoryVisibility.WORLD_READABLE, True, True),
        (HistoryVisibility.SHARED, True, True),
        # Because invites are expected to see room avatars
        (HistoryVisibility.INVITED, False, True),
        (HistoryVisibility.JOINED, False, True),
    ]
)
def test_room_avatar_media_visibility_to_users_invited_to_room(
    self,
    visibility: str,
    expected_before_result_bool: bool,
    expected_after_result_bool: bool,
) -> None:
    """
    Check the room avatar set at room creation against an invited user.

    The avatar is checked against `expected_before_result_bool` before the invite
    is sent and against `expected_after_result_bool` once it has been sent. The
    user never joins the room.
    """
    localpart = "invited_user_room_avatar_test"
    invitee = self.register_user(localpart, "password")
    invitee_id = UserID.from_string(invitee)
    self.login(localpart, "password")

    avatar_mxc = self.create_restricted_media()
    room_id = self.create_test_room(visibility, avatar_mxc)

    # Look up the media info for the room avatar. The avatar was set after the
    # visibility event was created and is therefore governed by it.
    room_avatar = self.retrieve_media_from_store(avatar_mxc)
    self.assert_expected_result(
        invitee_id, room_avatar, expected_before_result_bool
    )

    self.helper.invite(room_id, self.alice_user_id, invitee, tok=self.alice_tok)

    # Same avatar, now that the invite exists
    self.assert_expected_result(
        invitee_id, room_avatar, expected_after_result_bool
    )
+ first_media_object = self.retrieve_media_from_store(creation_room_avatar_mxc) + + self.assert_expected_result( + joining_user_id, first_media_object, expected_before_result_bool + ) + + # Now we join the room, and examine the media object again + self.helper.join(room_id, joining_user, tok=joining_user_tok) + + self.assert_expected_result( + joining_user_id, first_media_object, expected_after_result_bool + ) + + # Do a second one that is definitely after the join + second_media_object = self.insert_room_avatar_with_attached_media(room_id) + + self.assert_expected_result( + joining_user_id, second_media_object, expected_after_result_bool + ) + + @parameterized.expand( + [ + (HistoryVisibility.WORLD_READABLE, True, True), + (HistoryVisibility.SHARED, True, True), + # Since the room avatar is set before the join, + # after the leave it is not visible at all + (HistoryVisibility.INVITED, False, False), + (HistoryVisibility.JOINED, False, False), + ] + ) + def test_room_avatar_media_visibility_to_users_that_left_room( + self, + visibility: str, + expected_before_result_bool: bool, + expected_after_result_bool: bool, + ) -> None: + """ + Test that after leaving a room, a user can not see changes to the room's avatar + """ + + leaving_user = self.register_user("leaving_user_room_avatar_test", "password") + leaving_user_id = UserID.from_string(leaving_user) + leaving_user_tok = self.login("leaving_user_room_avatar_test", "password") + + # We will need two pieces of media: one to check for from room creation and then after the leave + + creation_room_avatar_mxc = self.create_restricted_media() + room_id = self.create_test_room(visibility, creation_room_avatar_mxc) + # Retrieve the media info for the room avatar. This will have been set after the + # visibility event was created and is therefore governed by it. 
@override_config({"limit_profile_requests_to_users_who_share_rooms": True})
def test_global_profile_is_not_visible_when_not_sharing_a_room_setting_is_enabled(
    self,
) -> None:
    """
    Check a global profile avatar with profile lookups limited to shared rooms.

    With "limit_profile_requests_to_users_who_share_rooms" enabled, alice is
    expected to see her own avatar while a freshly registered user, who has not
    been put in any room with her here, is not.
    """
    viewer = self.register_user("profile_viewing_user", "password")
    viewer_id = UserID.from_string(viewer)

    # Several calls below want the parsed UserID rather than the string,
    # so parse it once to simplify them
    alice_id = UserID.from_string(self.alice_user_id)

    avatar_mxc = self.create_restricted_media()

    # set_avatar_url() is handed a Requester object, not just a user ID
    requester = create_requester(self.alice_user_id)
    set_avatar = self.profile_handler.set_avatar_url(
        alice_id, requester, str(avatar_mxc)
    )
    self.get_success(set_avatar)

    stored_media = self.get_success(
        self.store.get_local_media(avatar_mxc.media_id)
    )
    assert stored_media is not None

    self.assert_expected_result(alice_id, stored_media, True)
    self.assert_expected_result(viewer_id, stored_media, False)
class RestrictedResourceDownloadTestCase(unittest.HomeserverTestCase):
    """
    Test the `/download` media endpoint for restricted media.

    Something to note: rooms here will be set to room history visibility of 'joined'
    at a minimum, or the media would be visible by default
    """

    servlets = [
        media.register_servlets,
        login.register_servlets,
        admin.register_servlets,
        room.register_servlets,
    ]

    def default_config(self) -> JsonDict:
        """Return the default test config with the MSC3911 experimental flag on."""
        config = super().default_config()
        config.setdefault("experimental_features", {})
        config["experimental_features"].update({"msc3911_enabled": True})
        return config

    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
        """Register and log in the three users shared by the tests below."""
        self.repo = hs.get_media_repository()
        self.store = hs.get_datastores().main
        self.creator = self.register_user("creator", "testpass")
        self.creator_tok = self.login("creator", "testpass")
        self.other_user = self.register_user("random_user", "testpass")
        self.other_user_tok = self.login("random_user", "testpass")
        self.other_profile_test_user = self.register_user(
            "profile_test_user", "testpass"
        )
        self.other_profile_test_user_tok = self.login("profile_test_user", "testpass")

    def create_resource_dict(self) -> dict[str, Resource]:
        """Add the media repository resource to the default resource tree."""
        resources = super().create_resource_dict()
        resources["/_matrix/media"] = self.hs.get_media_repository_resource()
        return resources

    def _create_restricted_media(self, user: str) -> MXCUri:
        """Store SMALL_PNG as restricted media for `user` and return its MXC URI."""
        mxc_uri = self.get_success(
            self.repo.create_or_update_content(
                "image/png",
                "test_png_upload",
                io.BytesIO(SMALL_PNG),
                67,
                UserID.from_string(user),
                restricted=True,
            )
        )
        return mxc_uri

    def _create_room_with_joined_visibility(self) -> str:
        """Create a room as the creator and set its history visibility to 'joined'."""
        room_id = self.helper.create_room_as(self.creator, tok=self.creator_tok)

        # set room history_visibility to joined, otherwise it will be 'shared'
        self.helper.send_state(
            room_id=room_id,
            event_type=EventTypes.RoomHistoryVisibility,
            body={"history_visibility": HistoryVisibility.JOINED},
            tok=self.creator_tok,
        )
        return room_id

    def _restrict_media_to_profile(self, mxc_uri: MXCUri, user: str) -> None:
        """
        Attach the media to `user`'s profile directly in the database; these tests
        are not here to exercise the profile endpoint.
        """
        self.get_success(
            self.store.set_media_restricted_to_user_profile(
                mxc_uri.server_name,
                mxc_uri.media_id,
                user,
            )
        )

    def _send_image_message(self, room_id: str, mxc_uri: MXCUri) -> None:
        """Send an `m.image` message as the creator with `mxc_uri` attached to it."""
        # TODO: verify this file info is legit, because it does not match SMALL_PNG. It
        # seems to work tho, oddly
        image = {
            "body": "test_png_upload",
            "info": {"h": 1, "mimetype": "image/png", "size": 67, "w": 1},
            "msgtype": "m.image",
            "url": str(mxc_uri),
        }
        json_body = self.helper.send_event(
            room_id,
            "m.room.message",
            content=image,
            tok=self.creator_tok,
            expect_code=200,
            attach_media_mxc=str(mxc_uri),
        )
        assert "event_id" in json_body

    def _send_membership_avatar(self, room_id: str, mxc_uri: MXCUri) -> None:
        """Update the creator's join membership so `mxc_uri` is its attached avatar."""
        membership_content = {
            EventContentFields.MEMBERSHIP: Membership.JOIN,
            "avatar_url": str(mxc_uri),
        }
        json_body = self.helper.send_state(
            room_id,
            EventTypes.Member,
            body=membership_content,
            tok=self.creator_tok,
            expect_code=200,
            state_key=self.creator,
            attach_media_mxc=str(mxc_uri),
        )
        assert "event_id" in json_body

    def fetch_media(
        self,
        mxc_uri: MXCUri,
        access_token: Optional[str] = None,
        expected_code: int = 200,
    ) -> None:
        """
        Test retrieving the media. We do not care about the content of the media, just
        that the response is correct

        The creator's access token is used when `access_token` is not given.
        """
        channel = self.make_request(
            "GET",
            f"/_matrix/client/v1/media/download/{mxc_uri.server_name}/{mxc_uri.media_id}",
            access_token=access_token or self.creator_tok,
        )
        assert channel.code == expected_code, channel.code

    def test_user_download_local_media_unrestricted(self) -> None:
        """Test that unrestricted media is not affected"""
        mxc_uri = self.get_success(
            self.repo.create_or_update_content(
                "image/png",
                "test_png_upload",
                io.BytesIO(SMALL_PNG),
                67,
                UserID.from_string(self.other_user),
                restricted=False,
            )
        )
        # The assertion of 200 as a response code is part of the function
        self.fetch_media(mxc_uri)
        self.fetch_media(mxc_uri, access_token=self.other_user_tok)

    def test_download_local_media_restricted_but_pending_state(self) -> None:
        """Test originating user can access media even though it is not attached"""
        mxc_uri = self._create_restricted_media(self.creator)
        # The creator user can see their own media
        self.fetch_media(mxc_uri)
        # But another user can not
        self.fetch_media(mxc_uri, access_token=self.other_user_tok, expected_code=403)

    def test_user_download_local_media_attached_to_user_profile_success(self) -> None:
        """Test retrieving media attached to user's profile"""
        prime_mxc_uri = self._create_restricted_media(self.creator)
        other_mxc_uri = self._create_restricted_media(self.other_profile_test_user)
        self._restrict_media_to_profile(prime_mxc_uri, self.creator)
        self._restrict_media_to_profile(other_mxc_uri, self.other_profile_test_user)

        # Should be able to see their own
        self.fetch_media(prime_mxc_uri, access_token=self.creator_tok)
        self.fetch_media(other_mxc_uri, access_token=self.other_profile_test_user_tok)

        # Should be able to see each others
        self.fetch_media(other_mxc_uri, access_token=self.creator_tok)
        self.fetch_media(prime_mxc_uri, access_token=self.other_profile_test_user_tok)

    @override_config(
        {
            "limit_profile_requests_to_users_who_share_rooms": True,
        }
    )
    def test_user_download_local_media_attached_to_user_profile_failure(self) -> None:
        """
        Test that limiting profile requests works as expected. Specifically, that users
        that are not sharing a room can not see profile avatars
        """
        prime_mxc_uri = self._create_restricted_media(self.creator)
        other_mxc_uri = self._create_restricted_media(self.other_profile_test_user)
        self._restrict_media_to_profile(prime_mxc_uri, self.creator)
        self._restrict_media_to_profile(other_mxc_uri, self.other_profile_test_user)

        # Should be able to see their own
        self.fetch_media(prime_mxc_uri, access_token=self.creator_tok)
        self.fetch_media(other_mxc_uri, access_token=self.other_profile_test_user_tok)

        # Should NOT be able to see each others, since the limitation setting is enabled
        self.fetch_media(
            other_mxc_uri, access_token=self.creator_tok, expected_code=403
        )
        self.fetch_media(
            prime_mxc_uri,
            access_token=self.other_profile_test_user_tok,
            expected_code=403,
        )

    def test_user_download_local_media_attached_to_message_event_success(self) -> None:
        """Test that local media attached to an image event can be viewed"""
        mxc_uri = self._create_restricted_media(self.creator)
        room_id = self._create_room_with_joined_visibility()

        # The other user joins BEFORE the image is sent
        self.helper.join(room_id, self.other_user, tok=self.other_user_tok)
        self._send_image_message(room_id, mxc_uri)

        # Both users should be able to see the event
        self.fetch_media(mxc_uri)
        self.fetch_media(mxc_uri, access_token=self.other_user_tok)

    def test_user_download_local_media_attached_to_message_event_failure(self) -> None:
        """Test that local media attached to an image event can be restricted"""
        mxc_uri = self._create_restricted_media(self.creator)
        room_id = self._create_room_with_joined_visibility()

        self._send_image_message(room_id, mxc_uri)

        # Specifically, join the user AFTER sending the attaching message
        self.helper.join(room_id, self.other_user, tok=self.other_user_tok)

        self.fetch_media(mxc_uri)
        # The other user was not in the room at the time the image was sent, so this
        # should fail.
        self.fetch_media(mxc_uri, access_token=self.other_user_tok, expected_code=403)

    def test_user_download_local_media_attached_to_state_event_success(self) -> None:
        """Test that a simple membership avatar is viewable when appropriate"""
        mxc_uri = self._create_restricted_media(self.creator)
        room_id = self._create_room_with_joined_visibility()

        # The other user joins BEFORE the membership avatar is set
        self.helper.join(room_id, self.other_user, tok=self.other_user_tok)
        self._send_membership_avatar(room_id, mxc_uri)

        # Both users should be able to see the media
        self.fetch_media(mxc_uri)
        self.fetch_media(mxc_uri, access_token=self.other_user_tok)

    def test_user_download_local_media_attached_to_state_event_failure(self) -> None:
        """
        Test a membership avatar that was set before the other user joined.

        NOTE(review): despite the `_failure` suffix, both downloads are asserted to
        succeed: the late joiner is expected to be able to fetch the avatar.
        """
        mxc_uri = self._create_restricted_media(self.creator)
        room_id = self._create_room_with_joined_visibility()

        self._send_membership_avatar(room_id, mxc_uri)

        # Specifically, join the user AFTER the membership avatar was set
        self.helper.join(room_id, self.other_user, tok=self.other_user_tok)

        self.fetch_media(mxc_uri)
        # This user has joined the room and can now see this image. Can't see the
        # related membership event, but :man-shrug:
        self.fetch_media(mxc_uri, access_token=self.other_user_tok)
+# + +import io +from typing import Optional + +from matrix_common.types.mxc_uri import MXCUri + +from twisted.internet.testing import MemoryReactor +from twisted.web.resource import Resource + +from synapse.api.constants import ( + EventContentFields, + EventTypes, + HistoryVisibility, + Membership, +) +from synapse.rest import admin +from synapse.rest.client import login, media, room +from synapse.server import HomeServer +from synapse.types import JsonDict, UserID +from synapse.util import Clock + +from tests import unittest +from tests.server import FakeChannel +from tests.test_utils import SMALL_PNG +from tests.unittest import override_config + + +class RestrictedResourceThumbnailTestCase(unittest.HomeserverTestCase): + """ + Test the `/thumbnail` media endpoint for restricted media. + + Something to note: rooms here will be set to room history visibility of 'joined' + at a minimum, or the media would be visible by default + """ + + servlets = [ + media.register_servlets, + login.register_servlets, + admin.register_servlets, + room.register_servlets, + ] + + def default_config(self) -> JsonDict: + config = super().default_config() + config.setdefault("experimental_features", {}) + config["experimental_features"].update({"msc3911_enabled": True}) + # This is what the defaults are for both 'crop' and 'scale' as reference + # We don't need to set these, but it's good to know + # "thumbnail_sizes": [ + # {"width": 32, "height": 32, "method": "crop"}, + # {"width": 240, "height": 320, "method": "scale"}, + # ], + return config + + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + self.repo = hs.get_media_repository() + self.store = hs.get_datastores().main + self.creator = self.register_user("creator", "testpass") + self.creator_tok = self.login("creator", "testpass") + self.other_user = self.register_user("random_user", "testpass") + self.other_user_tok = self.login("random_user", "testpass") + self.other_profile_test_user = 
self.register_user( + "profile_test_user", "testpass" + ) + self.other_profile_test_user_tok = self.login("profile_test_user", "testpass") + + def create_resource_dict(self) -> dict[str, Resource]: + resources = super().create_resource_dict() + resources["/_matrix/media"] = self.hs.get_media_repository_resource() + return resources + + def _create_restricted_media(self, user: str) -> MXCUri: + """ + Insert our media directly into the database/repo. This creates the necessary + rows and sets the media as 'restricted' but does not establish any attachments. + """ + mxc_uri = self.get_success( + self.repo.create_or_update_content( + "image/png", + "test_png_upload", + io.BytesIO(SMALL_PNG), + 67, + UserID.from_string(user), + restricted=True, + ) + ) + return mxc_uri + + def fetch_thumbnail( + self, + mxc_uri: MXCUri, + method: str = "crop", + access_token: Optional[str] = None, + expect_code: int = 200, + ) -> FakeChannel: + """ + Attempt media retrieval from the `/thumbnail` endpoint. Asserts the expected code + before returning the raw channel + """ + params = "?width=1&height=1&method=" + method + channel = self.make_request( + "GET", + f"/_matrix/client/v1/media/thumbnail/{mxc_uri.server_name}/{mxc_uri.media_id}{params}", + access_token=access_token or self.creator_tok, + ) + assert channel.code == expect_code, channel.code + return channel + + def test_user_download_local_media_thumbnail_unrestricted(self) -> None: + """Test that unrestricted media is not affected""" + # Note that 'restricted' is marked as 'False' here + content_mxc_uri = self.get_success( + self.repo.create_or_update_content( + "image/png", + "test_png_upload", + io.BytesIO(SMALL_PNG), + 67, + UserID.from_string(self.other_user), + restricted=False, + ) + ) + # The assertion of 200 as a response code is part of the function + self.fetch_thumbnail(content_mxc_uri) + self.fetch_thumbnail(content_mxc_uri, access_token=self.other_user_tok) + + def test_download_local_media_restricted_but_pending_state(self) 
-> None: + """Test originating user can access media even though it is not attached""" + mxc_uri = self._create_restricted_media(self.creator) + + # The creator user can see their own media + self.fetch_thumbnail(mxc_uri) + # But another user can not + self.fetch_thumbnail(mxc_uri, access_token=self.other_user_tok, expect_code=403) + + def test_user_download_local_media_attached_to_user_profile_success(self) -> None: + """Test retrieving media attached to user's profile""" + prime_mxc_uri = self._create_restricted_media(self.creator) + other_mxc_uri = self._create_restricted_media(self.other_profile_test_user) + # Inject directly to the database, we are not here to test the profile endpoint + self.get_success( + self.store.set_media_restricted_to_user_profile( + prime_mxc_uri.server_name, + prime_mxc_uri.media_id, + self.creator, + ) + ) + self.get_success( + self.store.set_media_restricted_to_user_profile( + other_mxc_uri.server_name, + other_mxc_uri.media_id, + self.other_profile_test_user, + ) + ) + + # Should be able to see their own + self.fetch_thumbnail(prime_mxc_uri, access_token=self.creator_tok) + self.fetch_thumbnail( + other_mxc_uri, access_token=self.other_profile_test_user_tok + ) + + # Should be able to see each others + self.fetch_thumbnail(other_mxc_uri, access_token=self.creator_tok) + self.fetch_thumbnail( + prime_mxc_uri, access_token=self.other_profile_test_user_tok + ) + + @override_config( + { + "limit_profile_requests_to_users_who_share_rooms": True, + } + ) + def test_user_download_local_media_attached_to_user_profile_failure(self) -> None: + """ + Test that limiting profile requests works as expected. 
Specifically, that users + that are not sharing a room can not see profile avatars + """ + + prime_mxc_uri = self._create_restricted_media(self.creator) + other_mxc_uri = self._create_restricted_media(self.other_profile_test_user) + # Inject directly to the database, we are not here to test the profile endpoint + self.get_success( + self.store.set_media_restricted_to_user_profile( + prime_mxc_uri.server_name, + prime_mxc_uri.media_id, + self.creator, + ) + ) + self.get_success( + self.store.set_media_restricted_to_user_profile( + other_mxc_uri.server_name, + other_mxc_uri.media_id, + self.other_profile_test_user, + ) + ) + + # Should be able to see their own + self.fetch_thumbnail(prime_mxc_uri, access_token=self.creator_tok) + self.fetch_thumbnail( + other_mxc_uri, access_token=self.other_profile_test_user_tok + ) + + # Should NOT be able to see each others, since the limitation setting is enabled + self.fetch_thumbnail( + other_mxc_uri, access_token=self.creator_tok, expect_code=403 + ) + self.fetch_thumbnail( + prime_mxc_uri, + access_token=self.other_profile_test_user_tok, + expect_code=403, + ) + + def test_users_download_local_media_attached_to_message_event_success(self) -> None: + """Test that local media attached to an image event can be viewed""" + mxc_uri = self._create_restricted_media(self.creator) + room_id = self.helper.create_room_as(self.creator, tok=self.creator_tok) + + # set room history_visibility to joined, otherwise it will be 'shared' + self.helper.send_state( + room_id=room_id, + event_type=EventTypes.RoomHistoryVisibility, + body={"history_visibility": HistoryVisibility.JOINED}, + tok=self.creator_tok, + ) + + _ = self.helper.join(room_id, self.other_user, tok=self.other_user_tok) + # TODO: verify this file info is legit, because it does not match SMALL_PNG. 
It + # seems to work tho, oddly + image = { + "body": "test_png_upload", + "info": {"h": 1, "mimetype": "image/png", "size": 67, "w": 1}, + "msgtype": "m.image", + "url": str(mxc_uri), + } + json_body = self.helper.send_event( + room_id, + "m.room.message", + content=image, + tok=self.creator_tok, + expect_code=200, + attach_media_mxc=str(mxc_uri), + ) + assert "event_id" in json_body + + # Both users should be able to see the event + self.fetch_thumbnail(mxc_uri) + self.fetch_thumbnail(mxc_uri, access_token=self.other_user_tok) + + def test_users_download_local_media_attached_to_message_event_failure(self) -> None: + """Test that local media attached to an image event can be restricted""" + mxc_uri = self._create_restricted_media(self.creator) + room_id = self.helper.create_room_as(self.creator, tok=self.creator_tok) + + # set room history_visibility to joined + self.helper.send_state( + room_id=room_id, + event_type=EventTypes.RoomHistoryVisibility, + body={"history_visibility": HistoryVisibility.JOINED}, + tok=self.creator_tok, + ) + + image = { + "body": "test_png_upload", + "info": {"h": 1, "mimetype": "image/png", "size": 67, "w": 1}, + "msgtype": "m.image", + "url": str(mxc_uri), + } + json_body = self.helper.send_event( + room_id, + "m.room.message", + content=image, + tok=self.creator_tok, + expect_code=200, + attach_media_mxc=str(mxc_uri), + ) + assert "event_id" in json_body + + # Specifically, join the user AFTER sending the attaching message + self.helper.join(room_id, self.other_user, tok=self.other_user_tok) + + self.fetch_thumbnail(mxc_uri) + # The other user was not in the room at the time the image was sent, so this + # should fail. 
+ self.fetch_thumbnail(mxc_uri, access_token=self.other_user_tok, expect_code=403) + + def test_user_download_local_media_attached_to_state_event_success(self) -> None: + """Test that a simple membership avatar is viewable when appropriate""" + mxc_uri = self._create_restricted_media(self.creator) + room_id = self.helper.create_room_as(self.creator, tok=self.creator_tok) + + # set room history_visibility to joined + self.helper.send_state( + room_id=room_id, + event_type=EventTypes.RoomHistoryVisibility, + body={"history_visibility": HistoryVisibility.JOINED}, + tok=self.creator_tok, + ) + + _ = self.helper.join(room_id, self.other_user, tok=self.other_user_tok) + + membership_content = { + EventContentFields.MEMBERSHIP: Membership.JOIN, + "avatar_url": str(mxc_uri), + } + json_body = self.helper.send_state( + room_id, + EventTypes.Member, + body=membership_content, + tok=self.creator_tok, + expect_code=200, + state_key=self.creator, + attach_media_mxc=str(mxc_uri), + ) + assert "event_id" in json_body + + # Both users should be able to see the media + self.fetch_thumbnail(mxc_uri) + self.fetch_thumbnail(mxc_uri, access_token=self.other_user_tok) + + def test_user_download_local_media_attached_to_state_event_failure(self) -> None: + """Test that a simple membership avatar is restricted when appropriate""" + mxc_uri = self._create_restricted_media(self.creator) + room_id = self.helper.create_room_as(self.creator, tok=self.creator_tok) + # set room history_visibility to joined + self.helper.send_state( + room_id=room_id, + event_type=EventTypes.RoomHistoryVisibility, + body={"history_visibility": HistoryVisibility.JOINED}, + tok=self.creator_tok, + ) + + membership_content = { + EventContentFields.MEMBERSHIP: Membership.JOIN, + "avatar_url": str(mxc_uri), + } + json_body = self.helper.send_state( + room_id, + EventTypes.Member, + body=membership_content, + tok=self.creator_tok, + expect_code=200, + state_key=self.creator, + attach_media_mxc=str(mxc_uri), + ) + assert 
"event_id" in json_body + + _ = self.helper.join(room_id, self.other_user, tok=self.other_user_tok) + + self.fetch_thumbnail(mxc_uri) + # This user has joined the room and can now see this image. Can't see the + # related membership event, but :man-shrug: + self.fetch_thumbnail(mxc_uri, access_token=self.other_user_tok) diff --git a/tests/rest/client/utils.py b/tests/rest/client/utils.py index dacc5639e3..62a1b05d8e 100644 --- a/tests/rest/client/utils.py +++ b/tests/rest/client/utils.py @@ -378,6 +378,7 @@ def send( expect_code: int = HTTPStatus.OK, custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None, type: str = "m.room.message", + attach_media_mxc: Optional[str] = None, ) -> JsonDict: if body is None: body = "body_text_here" @@ -392,6 +393,7 @@ def send( tok, expect_code, custom_headers=custom_headers, + attach_media_mxc=attach_media_mxc, ) def send_event( @@ -403,13 +405,22 @@ def send_event( tok: Optional[str] = None, expect_code: int = HTTPStatus.OK, custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None, + attach_media_mxc: Optional[str] = None, ) -> JsonDict: if txn_id is None: txn_id = "m%s" % (str(time.time())) path = "/_matrix/client/r0/rooms/%s/send/%s/%s" % (room_id, type, txn_id) + url_params: Dict[str, str] = {} + if tok: - path = path + "?access_token=%s" % tok + url_params["access_token"] = tok + + if attach_media_mxc: + url_params["org.matrix.msc3911.attach_media"] = attach_media_mxc + + if url_params: + path += "?" 
+ urlencode(url_params) channel = make_request( self.reactor, @@ -474,6 +485,7 @@ def _read_write_state( expect_code: int = HTTPStatus.OK, state_key: str = "", method: str = "GET", + attach_media_mxc: Optional[str] = None, ) -> JsonDict: """Read or write some state from a given room @@ -486,6 +498,8 @@ def _read_write_state( expect_code: The HTTP code to expect in the response state_key: method: "GET" or "PUT" for reading or writing state, respectively + attach_media_mxc: The MXC to attach to the state event, per msc3911. Only makes + sense when "PUT"ting the state Returns: The response body from the server @@ -498,8 +512,16 @@ def _read_write_state( event_type, state_key, ) + url_params: Dict[str, str] = {} + if tok: - path = path + "?access_token=%s" % tok + url_params["access_token"] = tok + + if attach_media_mxc: + url_params["org.matrix.msc3911.attach_media"] = attach_media_mxc + + if url_params: + path += "?" + urlencode(url_params) # Set request body if provided content = b"" @@ -551,6 +573,7 @@ def send_state( tok: Optional[str] = None, expect_code: int = HTTPStatus.OK, state_key: str = "", + attach_media_mxc: Optional[str] = None, ) -> JsonDict: """Set some state in a room @@ -561,6 +584,7 @@ def send_state( tok: The access token to use expect_code: The HTTP code to expect in the response state_key: + attach_media_mxc: The media MXC uri to attach to this state event Returns: The response body from the server @@ -569,7 +593,14 @@ def send_state( AssertionError: if expect_code doesn't match the HTTP code we received """ return self._read_write_state( - room_id, event_type, body, tok, expect_code, state_key, method="PUT" + room_id, + event_type, + body, + tok, + expect_code, + state_key, + method="PUT", + attach_media_mxc=attach_media_mxc, ) def upload_media( From dbf7649f868e5263f1620823d8af79a903766168 Mon Sep 17 00:00:00 2001 From: Soyoung Kim Date: Tue, 9 Sep 2025 17:11:52 +0200 Subject: [PATCH 20/35] feat: add media permission check on copy endpoint --- 
synapse/media/media_repository.py | 2 + synapse/rest/client/media.py | 11 +- .../databases/main/media_repository.py | 3 +- tests/rest/client/test_media.py | 211 ++++++++++++++++++ 4 files changed, 224 insertions(+), 3 deletions(-) diff --git a/synapse/media/media_repository.py b/synapse/media/media_repository.py index d28319c00f..b115b1071d 100644 --- a/synapse/media/media_repository.py +++ b/synapse/media/media_repository.py @@ -458,6 +458,8 @@ async def get_local_media_info( # The file has been uploaded, so stop looping if media_info.media_length is not None: + if isinstance(request.requester, Requester): + await self.is_media_visible(request.requester.user, media_info) return media_info # Check if the media ID has expired and still hasn't been uploaded to. diff --git a/synapse/rest/client/media.py b/synapse/rest/client/media.py index 80b9b8c65b..ada6c34bc4 100644 --- a/synapse/rest/client/media.py +++ b/synapse/rest/client/media.py @@ -349,18 +349,25 @@ async def on_POST( # Optionally parse request body, must be a JSON object, but no required params. 
_ = parse_json_object_from_request(request, allow_empty_body=True) + max_timeout_ms = parse_integer( + request, "timeout_ms", default=DEFAULT_MAX_TIMEOUT_MS + ) + max_timeout_ms = min(max_timeout_ms, MAXIMUM_ALLOWED_MAX_TIMEOUT_MS) media_info: Union[LocalMedia, RemoteMedia, None] = None if self._is_mine_server_name(server_name): - media_info = await self.store.get_local_media(media_id) + media_info = await self.media_repo.get_local_media_info( + request, media_id, max_timeout_ms + ) else: media_info = await self.media_repo.get_remote_media_info( server_name, media_id, - DEFAULT_MAX_TIMEOUT_MS, + max_timeout_ms, request.getClientAddress().host, use_federation=True, allow_authenticated=True, + requester=requester, ) if not media_info: diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py index 0a245dea0f..6dd6d10b4c 100644 --- a/synapse/storage/databases/main/media_repository.py +++ b/synapse/storage/databases/main/media_repository.py @@ -768,7 +768,8 @@ async def get_cached_remote_media( if row is None: return row restriction_info = None - if row[9] is not None and row[9] is True: + + if row[9] is not None and row[9]: restriction_info = await self.get_media_restrictions(origin, media_id) return RemoteMedia( diff --git a/tests/rest/client/test_media.py b/tests/rest/client/test_media.py index 1b1cf7641a..6a6260affa 100644 --- a/tests/rest/client/test_media.py +++ b/tests/rest/client/test_media.py @@ -3201,6 +3201,7 @@ class CopyRestrictedResource(unittest.HomeserverTestCase): media.register_servlets, login.register_servlets, admin.register_servlets, + room.register_servlets, ] def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: @@ -3210,6 +3211,7 @@ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.media_repo = hs.get_media_repository() + self.profile_handler = 
self.hs.get_profile_handler() self.user = self.register_user("user", "testpass") self.user_tok = self.login("user", "testpass") self.other_user = self.register_user("other", "testpass") @@ -3224,6 +3226,25 @@ def test_copy_local_restricted_resource(self) -> None: """ Tests that the new copy endpoint creates a new mxc uri for restricted resource. """ + # Create a private room + room_id = self.helper.create_room_as( + self.user, + is_public=False, + tok=self.user_tok, + extra_content={ + "initial_state": [ + { + "type": EventTypes.RoomHistoryVisibility, + "state_key": "", + "content": {"history_visibility": HistoryVisibility.JOINED}, + }, + ] + }, + ) + # Invite the other user + self.helper.invite(room_id, self.user, self.other_user, tok=self.user_tok) + self.helper.join(room_id, self.other_user, tok=self.other_user_tok) + # The media is created with user_tok content = io.BytesIO(SMALL_PNG) content_uri = self.get_success( @@ -3238,6 +3259,24 @@ def test_copy_local_restricted_resource(self) -> None: ) media_id = content_uri.media_id + # User sends a message with media + channel = self.make_request( + "PUT", + f"/rooms/{room_id}/send/m.room.message/{str(time.time())}?org.matrix.msc3911.attach_media={str(content_uri)}", + content={"msgtype": "m.text", "body": "Hi, this is a message"}, + access_token=self.user_tok, + ) + self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body) + assert "event_id" in channel.json_body + event_id = channel.json_body["event_id"] + restrictions = self.get_success( + self.hs.get_datastores().main.get_media_restrictions( + content_uri.server_name, content_uri.media_id + ) + ) + assert restrictions is not None, str(restrictions) + assert restrictions.event_id == event_id + # The other_user copies the media from local server channel = self.make_request( "POST", @@ -3269,6 +3308,103 @@ def test_copy_local_restricted_resource(self) -> None: # Check if media is unattached to any event or user profile yet. 
assert copied_media.attachments is None + def test_copy_local_restricted_resource_fails_when_requester_does_not_have_access( + self, + ) -> None: + """ + Tests that the new copy endpoint performs permission checks and it prevents the + copy when the requester does not have access to the original media. + """ + # Create a private room + room_id = self.helper.create_room_as( + self.user, + is_public=False, + tok=self.user_tok, + extra_content={ + "initial_state": [ + { + "type": EventTypes.RoomHistoryVisibility, + "state_key": "", + "content": {"history_visibility": HistoryVisibility.JOINED}, + }, + ] + }, + ) + + # Create the media content + content_uri = self.get_success( + self.media_repo.create_or_update_content( + "image/png", + "test_png_upload", + io.BytesIO(SMALL_PNG), + 67, + UserID.from_string(self.user), + restricted=True, + ) + ) + # User sends a message with media + channel = self.make_request( + "PUT", + f"/rooms/{room_id}/send/m.room.message/{str(time.time())}?org.matrix.msc3911.attach_media={str(content_uri)}", + content={"msgtype": "m.text", "body": "Hi, this is a message"}, + access_token=self.user_tok, + ) + self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body) + assert "event_id" in channel.json_body + event_id = channel.json_body["event_id"] + restrictions = self.get_success( + self.hs.get_datastores().main.get_media_restrictions( + content_uri.server_name, content_uri.media_id + ) + ) + assert restrictions is not None, str(restrictions) + assert restrictions.event_id == event_id + + # Invite the other user + self.helper.invite(room_id, self.user, self.other_user, tok=self.user_tok) + self.helper.join(room_id, self.other_user, tok=self.other_user_tok) + + # User who does not have access to the media tries to copy it. 
+ channel = self.make_request( + "POST", + f"/_matrix/client/unstable/org.matrix.msc3911/media/copy/{self.hs.hostname}/{content_uri.media_id}", + access_token=self.other_user_tok, + ) + self.assertEqual(channel.code, 403) + + @override_config( + { + "limit_profile_requests_to_users_who_share_rooms": True, + } + ) + def test_copy_local_restricted_resource_fails_when_profile_lookup_is_not_allowed( + self, + ) -> None: + # User setup a profile + content_uri = self.get_success( + self.media_repo.create_or_update_content( + "image/png", + "test_png_upload", + io.BytesIO(SMALL_PNG), + 67, + UserID.from_string(self.user), + restricted=True, + ) + ) + user_id = UserID.from_string(self.user) + self.get_success( + self.profile_handler.set_avatar_url( + user_id, create_requester(user_id), str(content_uri) + ) + ) + # The users do not share any rooms, and other user tries to copy the profile picture + channel = self.make_request( + "POST", + f"/_matrix/client/unstable/org.matrix.msc3911/media/copy/{self.hs.hostname}/{content_uri.media_id}", + access_token=self.other_user_tok, + ) + self.assertEqual(channel.code, 403) + def test_copy_remote_restricted_resource(self) -> None: """ Tests that the new copy endpoint creates a new mxc uri for restricted resource. 
@@ -3294,9 +3430,26 @@ def test_copy_remote_restricted_resource(self) -> None: upload_name="test.png", filesystem_id=remote_file_id, sha256=remote_file_id, + restricted=True, ) ) + # Remote media is attached to a user profile + remote_user_id = f"@remote-user:{remote_server}" + self.get_success( + self.hs.get_datastores().main.set_media_restricted_to_user_profile( + remote_server, media_id, remote_user_id + ) + ) + remote_media = self.get_success( + self.hs.get_datastores().main.get_cached_remote_media( + remote_server, media_id + ) + ) + assert remote_media is not None + assert remote_media.attachments is not None + assert str(remote_media.attachments.profile_user_id) == remote_user_id + # The other_user copies the media from remote server channel = self.make_request( "POST", @@ -3330,6 +3483,64 @@ def test_copy_remote_restricted_resource(self) -> None: # Check if copied media is unattached to any event or user profile yet. assert copied_media.attachments is None + @override_config( + { + "limit_profile_requests_to_users_who_share_rooms": True, + } + ) + def test_copy_remote_restricted_resource_fails_when_requester_does_not_have_access( + self, + ) -> None: + # Create remote media + remote_server = "remoteserver.com" + remote_file_id = "remote1" + file_info = FileInfo(server_name=remote_server, file_id=remote_file_id) + + media_storage = self.hs.get_media_repository().media_storage + ctx = media_storage.store_into_file(file_info) + (f, _) = self.get_success(ctx.__aenter__()) + f.write(SMALL_PNG) + self.get_success(ctx.__aexit__(None, None, None)) + media_id = "remotemedia" + self.get_success( + self.hs.get_datastores().main.store_cached_remote_media( + origin=remote_server, + media_id=media_id, + media_type="image/png", + media_length=1, + time_now_ms=self.clock.time_msec(), + upload_name="test.png", + filesystem_id=remote_file_id, + sha256=remote_file_id, + restricted=True, + ) + ) + + # Media is attached to a user profile + remote_user_id = 
f"@remote-user:{remote_server}" + self.get_success( + self.hs.get_datastores().main.set_media_restricted_to_user_profile( + remote_server, media_id, remote_user_id + ) + ) + remote_media = self.get_success( + self.hs.get_datastores().main.get_cached_remote_media( + remote_server, media_id + ) + ) + assert remote_media is not None + assert remote_media.attachments is not None + assert str(remote_media.attachments.profile_user_id) == remote_user_id + + # The other user tries to copy that media from remote server, but fails because + # user does not have the access to the profile_user_id + channel = self.make_request( + "POST", + f"/_matrix/client/unstable/org.matrix.msc3911/media/copy/{remote_server}/{media_id}", + access_token=self.other_user_tok, + ) + self.assertEqual(channel.code, 403) + class RestrictedMediaVisibilityTestCase(unittest.HomeserverTestCase): servlets = [ From 700d11e46db9612fc108bfa0380d013fa2551453 Mon Sep 17 00:00:00 2001 From: Jason Little Date: Wed, 10 Sep 2025 10:55:13 -0500 Subject: [PATCH 21/35] MSC3911 AP7.5: Accepting restrictions received from federation --- synapse/media/media_repository.py | 46 +++++++- .../databases/main/media_repository.py | 93 ++++++++++------ tests/rest/client/test_media.py | 100 +++++++++++++++++- 3 files changed, 196 insertions(+), 43 deletions(-) diff --git a/synapse/media/media_repository.py b/synapse/media/media_repository.py index b115b1071d..5020df60c1 100644 --- a/synapse/media/media_repository.py +++ b/synapse/media/media_repository.py @@ -75,8 +75,9 @@ MediaRestrictions, RemoteMedia, ) -from synapse.types import Requester, UserID +from synapse.types import JsonDict, Requester, UserID from synapse.types.state import StateFilter +from synapse.util import json_decoder from synapse.util.async_helpers import Linearizer from synapse.util.retryutils import NotRetryingDestination from synapse.util.stringutils import random_string @@ -1190,10 +1191,16 @@ async def _federation_download_remote_file( ) # if we had to 
fall back to the _matrix/media endpoint it will only return # the headers and length, check the length of the tuple before unpacking + attachment_dict: JsonDict if len(res) == 3: - length, headers, json = res + length, headers, json_bytes = res + if json_bytes: + attachment_dict = json_decoder.decode(json_bytes.decode()) else: length, headers = res + # This is set to an empty {} just as it is responded when media is + # not restricted, thus maintaining backwards compatibility + attachment_dict = {} except RequestSendFailed as e: logger.warning( "Request failed fetching remote media %s/%s: %r", @@ -1245,6 +1252,21 @@ async def _federation_download_remote_file( # alternative where we call `finish()` *after* this, where we could # end up having an entry in the DB but fail to write the files to # the storage providers. + + # The unstable prefix on 'restrictions' will be here. Do not save that to + # the database, but filter it out. This is the companion to it's opposite in + # MediaRestrictions.to_dict() which adds it while unstable. + if "org.matrix.msc3911.restrictions" in attachment_dict: + restrictions_values = attachment_dict.pop( + "org.matrix.msc3911.restrictions" + ) + attachment_dict["restrictions"] = restrictions_values + + # This can come in as 'falsey'(like '{}' or 'b""') so if this happens it has + # no restrictions. If it was restricted remotely, but had no attachments, + # then it should not have come across federation + restricted = True if "restrictions" in attachment_dict else False + await self.store.store_cached_remote_media( origin=server_name, media_id=media_id, @@ -1254,7 +1276,22 @@ async def _federation_download_remote_file( media_length=length, filesystem_id=file_id, sha256=sha256writer.hexdigest(), + restricted=restricted, ) + # TODO: Decide about raising here? It will delete the media from the + # disk but will not remove the restricted flag from the remote media + # entry that just got wrote. Is this important? 
According to the comment + # blocks above the last statement, it could raise a constraint violation + # which would block this from being called. But if it is racing, we may have + # been here before. Should this be gracefully handled(and basically ignored)? + # To keep the 'media_attachments' table smaller, unrestricted media does not + # have a row, only the restricted column for both local and remote media + attachments: Optional[MediaRestrictions] = None + if attachment_dict: + attachments = MediaRestrictions(**attachment_dict["restrictions"]) + await self.store.set_media_restrictions( + server_name, media_id, attachment_dict + ) logger.debug("Stored remote media in file %r", fname) @@ -1275,9 +1312,8 @@ async def _federation_download_remote_file( quarantined_by=None, authenticated=authenticated, sha256=sha256writer.hexdigest(), - # Update this when the federation responses are updated - restricted=False, - attachments=None, + restricted=restricted, + attachments=attachments, ) def _get_thumbnail_requirements( diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py index 6dd6d10b4c..18ad97959e 100644 --- a/synapse/storage/databases/main/media_repository.py +++ b/synapse/storage/databases/main/media_repository.py @@ -1175,40 +1175,35 @@ async def get_media_restrictions( return None - async def set_media_restricted_to_event_id( - self, - server_name: str, - media_id: str, - event_id: str, + async def set_media_restrictions( + self, server_name: str, media_id: str, json_dict: JsonDict ) -> None: """ - Add the media restrictions to a given Event ID to the database + Add the media restrictions for a given server_name and media_id to the database Args: server_name: media_id: - event_id: The Event ID to restrict the media to + json_dict: The dict with the restrictions Raises: SynapseError if the media already has restrictions on it """ - await self.db_pool.runInteraction( - 
"set_media_restricted_to_event_id", - self.set_media_restricted_to_event_id_txn, + return await self.db_pool.runInteraction( + "set_media_restrictions", + self.set_media_restrictions_txn, server_name=server_name, media_id=media_id, - event_id=event_id, + json_dict=json_dict, ) - def set_media_restricted_to_event_id_txn( + def set_media_restrictions_txn( self, txn: LoggingTransaction, - *, server_name: str, media_id: str, - event_id: str, + json_dict: JsonDict, ) -> None: - json_object = {"restrictions": {"event_id": event_id}} try: self.db_pool.simple_insert_txn( txn, @@ -1216,7 +1211,7 @@ def set_media_restricted_to_event_id_txn( { "server_name": server_name, "media_id": media_id, - "restrictions_json": json_encoder.encode(json_object), + "restrictions_json": json_encoder.encode(json_dict), }, ) except self.db_pool.engine.module.IntegrityError: @@ -1229,6 +1224,47 @@ def set_media_restricted_to_event_id_txn( errcode=Codes.INVALID_PARAM, ) + async def set_media_restricted_to_event_id( + self, + server_name: str, + media_id: str, + event_id: str, + ) -> None: + """ + Add the media restrictions to a given Event ID to the database + + Args: + server_name: + media_id: + event_id: The Event ID to restrict the media to + + Raises: + SynapseError if the media already has restrictions on it + """ + await self.db_pool.runInteraction( + "set_media_restricted_to_event_id", + self.set_media_restricted_to_event_id_txn, + server_name=server_name, + media_id=media_id, + event_id=event_id, + ) + + def set_media_restricted_to_event_id_txn( + self, + txn: LoggingTransaction, + *, + server_name: str, + media_id: str, + event_id: str, + ) -> None: + json_object = {"restrictions": {"event_id": event_id}} + self.set_media_restrictions_txn( + txn, + server_name=server_name, + media_id=media_id, + json_dict=json_object, + ) + async def set_media_restricted_to_user_profile( self, server_name: str, @@ -1266,22 +1302,9 @@ def set_media_restricted_to_user_profile_txn( profile_user_id: str, ) 
-> None: json_object = {"restrictions": {"profile_user_id": profile_user_id}} - try: - self.db_pool.simple_insert_txn( - txn, - "media_attachments", - { - "server_name": server_name, - "media_id": media_id, - "restrictions_json": json_encoder.encode(json_object), - }, - ) - except self.db_pool.engine.module.IntegrityError: - # For sqlite, a unique constraint violation is an integrity error. For - # psycopg2, a UniqueViolation is a subclass of IntegrityError, so this - # covers both. - raise SynapseError( - HTTPStatus.BAD_REQUEST, - f"This media, '{media_id}' already has restrictions set.", - errcode=Codes.INVALID_PARAM, - ) + self.set_media_restrictions_txn( + txn, + server_name=server_name, + media_id=media_id, + json_dict=json_object, + ) diff --git a/tests/rest/client/test_media.py b/tests/rest/client/test_media.py index 6a6260affa..e2a25efc26 100644 --- a/tests/rest/client/test_media.py +++ b/tests/rest/client/test_media.py @@ -28,7 +28,7 @@ from contextlib import nullcontext from http import HTTPStatus from typing import Any, BinaryIO, ClassVar, Dict, List, Optional, Sequence, Tuple, Type -from unittest.mock import MagicMock, Mock, patch +from unittest.mock import AsyncMock, MagicMock, Mock, patch from urllib import parse from urllib.parse import quote, urlencode @@ -65,10 +65,13 @@ from synapse.rest import admin from synapse.rest.client import login, media, room from synapse.server import HomeServer -from synapse.storage.databases.main.media_repository import LocalMedia +from synapse.storage.databases.main.media_repository import ( + LocalMedia, + MediaRestrictions, +) from synapse.types import JsonDict, UserID, create_requester from synapse.util import Clock, json_encoder -from synapse.util.stringutils import parse_and_validate_mxc_uri +from synapse.util.stringutils import parse_and_validate_mxc_uri, random_string from tests import unittest from tests.media.test_media_storage import ( @@ -4429,3 +4432,94 @@ def 
test_global_profile_is_not_visible_when_not_sharing_a_room_setting_is_enable assert media_object is not None self.assert_expected_result(alice_user_id, media_object, True) self.assert_expected_result(profile_viewing_user_id, media_object, False) + + +class FederationClientDownloadTestCase(unittest.HomeserverTestCase): + test_image = small_png + headers = { + b"Content-Length": [b"%d" % (len(test_image.data))], + b"Content-Type": [test_image.content_type], + b"Content-Disposition": [b"inline"], + } + + servlets = [ + media.register_servlets, + login.register_servlets, + admin.register_servlets, + ] + + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: + # Mock out the homeserver's MatrixFederationHttpClient + client = Mock() + federation_get_file = AsyncMock() + client.federation_get_file = federation_get_file + self.fed_client_mock = federation_get_file + + hs = self.setup_test_homeserver(federation_http_client=client) + + return hs + + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + self.store = hs.get_datastores().main + self.media_repo = hs.get_media_repository() + + self.remote_server = "example.com" + # mapping of media_id -> byte string of the json with the restrictions + self.media_id_data: Dict[str, bytes] = {} + + self.user = self.register_user("user", "pass") + self.tok = self.login("user", "pass") + + def generate_remote_media_id_and_restrictions( + self, json_dict_response: Optional[JsonDict] = None + ) -> str: + media_id = random_string(24) + byte_string = b"{}" + if json_dict_response: + byte_string = json_encoder.encode(json_dict_response).encode() + + self.media_id_data[media_id] = byte_string + return media_id + + def make_request_for_media( + self, json_dict_response: Optional[JsonDict] = None, expected_code: int = 200 + ) -> str: + # Generate media id and restrictions based on SMALL_PNG + media_id = self.generate_remote_media_id_and_restrictions(json_dict_response) + 
self.fed_client_mock.return_value = ( + 67, + self.headers, + self.media_id_data[media_id], + ) + + channel = self.make_request( + "GET", + f"/_matrix/client/v1/media/download/{self.remote_server}/{media_id}", + shorthand=False, + access_token=self.tok, + ) + + self.assertEqual(channel.code, expected_code) + + return media_id + + def test_downloading_remote_media_with_restrictions_is_in_database(self) -> None: + # Note the unstable prefix is filtered out properly before persistence + media_id = self.make_request_for_media( + {"org.matrix.msc3911.restrictions": {"profile_user_id": "@bob:example.com"}} + ) + restrictions = self.get_success( + self.store.get_media_restrictions(self.remote_server, media_id) + ) + assert isinstance(restrictions, MediaRestrictions) + assert restrictions.profile_user_id is not None + assert restrictions.profile_user_id.to_string() == "@bob:example.com" + + def test_downloading_remote_media_with_no_restrictions_does_not_save_to_db( + self, + ) -> None: + media_id = self.make_request_for_media() + restrictions = self.get_success( + self.store.get_media_restrictions(self.remote_server, media_id) + ) + assert restrictions is None From 5fc02642453984fe295be631a0a77f2b8db5c718 Mon Sep 17 00:00:00 2001 From: Jason Little Date: Thu, 11 Sep 2025 06:37:13 -0500 Subject: [PATCH 22/35] MSC3911 AP10: Ensure backwards compatibility --- tests/rest/client/test_media.py | 521 ++++++++++++++++++++++++++++++++ 1 file changed, 521 insertions(+) diff --git a/tests/rest/client/test_media.py b/tests/rest/client/test_media.py index e2a25efc26..fa703817ac 100644 --- a/tests/rest/client/test_media.py +++ b/tests/rest/client/test_media.py @@ -4523,3 +4523,524 @@ def test_downloading_remote_media_with_no_restrictions_does_not_save_to_db( self.store.get_media_restrictions(self.remote_server, media_id) ) assert restrictions is None + + +configs_2 = [ + {"enable_restricted_media": True}, + {"enable_restricted_media": False}, +] + + +@parameterized_class(configs_2) 
+class RestrictedMediaBackwardCompatTestCase(unittest.HomeserverTestCase): + """ + Test that restricted media can be downloaded if MSC3911 is enabled/disabled + """ + + enable_restricted_media: bool + + other_server_name = "remote-server.com" + servlets = [ + media.register_servlets, + login.register_servlets, + admin.register_servlets, + room.register_servlets, + ] + + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: + config = self.default_config() + + self.clock = clock + self.storage_path = self.mktemp() + self.media_store_path = self.mktemp() + os.mkdir(self.storage_path) + os.mkdir(self.media_store_path) + config["media_store_path"] = self.media_store_path + config["enable_authenticated_media"] = True + config["experimental_features"] = { + "msc3911_enabled": self.enable_restricted_media + } + + provider_config = { + "module": "synapse.media.storage_provider.FileStorageProviderBackend", + "store_local": True, + "store_synchronous": False, + "store_remote": True, + "config": {"directory": self.storage_path}, + } + + config["media_storage_providers"] = [provider_config] + + return self.setup_test_homeserver(config=config) + + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + self.repo = hs.get_media_repository() + self.client = hs.get_federation_http_client() + self.store = hs.get_datastores().main + self.user = self.register_user("user", "pass") + self.tok = self.login("user", "pass") + + self.other_user = self.register_user("other_user", "password") + self.other_user_tok = self.login("other_user", "password") + + self.room_id = self.helper.create_room_as(self.user, True, tok=self.tok) + assert self.room_id is not None + self.helper.join(self.room_id, self.other_user, tok=self.other_user_tok) + + def create_resource_dict(self) -> Dict[str, Resource]: + resources = super().create_resource_dict() + resources["/_matrix/media"] = self.hs.get_media_repository_resource() + return resources + + def 
upload_restricted_media(self) -> str: + """ + Upload media to the MSC3911 /media/upload endpoint. This will ensure the + restricted flag is set. Used for the authenticated ETag test + """ + # upload some local media with restrictions on + channel = self.make_request( + "POST", + "_matrix/client/unstable/org.matrix.msc3911/media/upload?filename=test_png_upload", + SMALL_PNG, + self.tok, + shorthand=False, + content_type=b"image/png", + custom_headers=[("Content-Length", str(67))], + ) + self.assertEqual(channel.code, 200, channel.json_body) + + res = channel.json_body.get("content_uri") + assert res is not None + _, restricted_uri = res.rsplit("/", maxsplit=1) + + return restricted_uri + + def upload_unrestricted_media(self) -> str: + """ + Upload media to the existing /_matrix/media/.../upload endpoint. This will not + set the restricted flag. Used for the authenticated ETag test + """ + # upload another using the old endpoint so it is not restricted + channel = self.make_request( + "POST", + "_matrix/media/v3/upload?filename=test_png_upload", + SMALL_PNG, + self.tok, + shorthand=False, + content_type=b"image/png", + custom_headers=[("Content-Length", str(67))], + ) + self.assertEqual(channel.code, 200, channel.json_body) + res = channel.json_body.get("content_uri") + assert res is not None + _, unrestricted_uri = res.rsplit("/", maxsplit=1) + + return unrestricted_uri + + def inject_local_media_and_send_event(self, authed: bool, restricted: bool) -> str: + """ + Inject the necessary database rows to ensure availability of a piece of local + media (optionally with restriction and authentication) + + Then send a message event that attaches this media to an event in the room. 
This + should set the restriction correctly + """ + + media_id = random_string(24) + file_id = media_id + file_info = FileInfo(None, file_id=file_id) + + media_storage = self.hs.get_media_repository().media_storage + + ctx = media_storage.store_into_file(file_info) + (f, fname) = self.get_success(ctx.__aenter__()) + f.write(SMALL_PNG) + self.get_success(ctx.__aexit__(None, None, None)) + + self.get_success( + self.store.db_pool.simple_insert( + "local_media_repository", + { + "media_id": media_id, + "media_type": "image/png", + "created_ts": self.clock.time_msec(), + "upload_name": "test_local", + "media_length": 67, + "user_id": self.other_user, + "url_cache": None, + "authenticated": authed, + "restricted": restricted, + }, + desc="store_local_media", + ) + ) + + # ensure we have thumbnails for the non-dynamic code path + self.get_success( + self.repo._generate_thumbnails(None, media_id, file_id, "image/png") + ) + + mxc_uri_str = f"mxc://test/{media_id}" + maybe_attach_media = None + if restricted: + maybe_attach_media = mxc_uri_str + image = { + "body": "test_png_upload", + "info": {"h": 1, "mimetype": "image/png", "size": 67, "w": 1}, + "msgtype": "m.image", + "url": mxc_uri_str, + } + + self.helper.send_event( + self.room_id, + EventTypes.Message, + content=image, + tok=self.other_user_tok, + attach_media_mxc=maybe_attach_media, + ) + + # We know it's the local server, so the server name is "test" + return media_id + + def inject_remote_media(self, restricted: bool) -> str: + """ + Inject the necessary database rows to ensure availability of a piece of remote + media (optionally with restriction). 
Abuse the profile_user_id field instead of + simulating a room + """ + + media_id = random_string(24) + file_id = media_id + file_info = FileInfo(server_name=self.other_server_name, file_id=file_id) + + media_storage = self.hs.get_media_repository().media_storage + + ctx = media_storage.store_into_file(file_info) + (f, fname) = self.get_success(ctx.__aenter__()) + f.write(SMALL_PNG) + self.get_success(ctx.__aexit__(None, None, None)) + + # we write the authenticated status when storing media, so this should pick up + # config and authenticate the media + self.get_success( + self.store.store_cached_remote_media( + origin=self.other_server_name, + media_id=media_id, + media_type="image/png", + media_length=1, + time_now_ms=self.clock.time_msec(), + upload_name="remote_test.png", + filesystem_id=file_id, + sha256=file_id, + restricted=restricted, + ) + ) + # add restrictions if appropriate, separate database call, use arg + if restricted: + # Since we are not going to generate a remote room and therefore have an + # event_id, just use the profile avatar(never mind that the user doesn't + # exist, the server doesn't care) + self.get_success( + self.store.set_media_restricted_to_user_profile( + self.other_server_name, + media_id, + f"@wilson:{self.other_server_name}", + ) + ) + + # ensure we have thumbnails for the non-dynamic code path + self.get_success( + self.repo._generate_thumbnails( + self.other_server_name, media_id, file_id, "image/png" + ) + ) + # The server is always going to be `self.other_server_name` so just need the media_id + return media_id + + def test_authed_restricted_local_media(self) -> None: + """ + Test that authenticated and restricted media is not available over the old + unauthenticated endpoints + """ + + restricted_media_id = self.inject_local_media_and_send_event( + authed=True, restricted=True + ) + # request media over authenticated endpoint, should be found + channel1 = self.make_request( + "GET", + 
f"_matrix/client/v1/media/download/test/{restricted_media_id}", + access_token=self.tok, + shorthand=False, + ) + self.assertEqual(channel1.code, 200, channel1) + + # request same media over unauthenticated media, should raise 404 not found + channel2 = self.make_request( + "GET", + f"_matrix/media/v3/download/test/{restricted_media_id}", + shorthand=False, + ) + self.assertEqual(channel2.code, 404, channel2) + + # check thumbnails as well + params = "?width=32&height=32&method=crop" + channel3 = self.make_request( + "GET", + f"/_matrix/client/v1/media/thumbnail/test/{restricted_media_id}{params}", + shorthand=False, + access_token=self.tok, + ) + self.assertEqual(channel3.code, 200, channel3) + + params = "?width=32&height=32&method=crop" + channel4 = self.make_request( + "GET", + f"/_matrix/media/r0/thumbnail/test/{restricted_media_id}{params}", + shorthand=False, + access_token=self.tok, + ) + self.assertEqual(channel4.code, 404, channel4) + + def test_unauthed_restricted_local_media(self) -> None: + """ + Test that unauthenticated but(somehow) restricted media is available over the old + unauthenticated endpoints + """ + # First test round, restricted media + restricted_media_id = self.inject_local_media_and_send_event( + authed=False, restricted=True + ) + # request media over authenticated endpoint, should be found + channel1 = self.make_request( + "GET", + f"_matrix/client/v1/media/download/test/{restricted_media_id}", + access_token=self.tok, + shorthand=False, + ) + self.assertEqual(channel1.code, 200, channel1) + + # request same media over unauthenticated media, should be found + channel2 = self.make_request( + "GET", + f"_matrix/media/v3/download/test/{restricted_media_id}", + shorthand=False, + ) + self.assertEqual(channel2.code, 200, channel2) + + # check thumbnails as well + params = "?width=32&height=32&method=crop" + channel3 = self.make_request( + "GET", + f"/_matrix/client/v1/media/thumbnail/test/{restricted_media_id}{params}", + shorthand=False, 
+ access_token=self.tok, + ) + self.assertEqual(channel3.code, 200, channel3) + + params = "?width=32&height=32&method=crop" + channel4 = self.make_request( + "GET", + f"/_matrix/media/r0/thumbnail/test/{restricted_media_id}{params}", + shorthand=False, + access_token=self.tok, + ) + self.assertEqual(channel4.code, 200, channel4) + + def test_authed_unrestricted_local_media(self) -> None: + """ + Test that authenticated and restricted media is not available over the old + unauthenticated endpoints + """ + + unrestricted_media_id = self.inject_local_media_and_send_event( + authed=True, restricted=False + ) + # request media over authenticated endpoint, should be found + channel1 = self.make_request( + "GET", + f"_matrix/client/v1/media/download/test/{unrestricted_media_id}", + access_token=self.tok, + shorthand=False, + ) + self.assertEqual(channel1.code, 200, channel1) + + # request same media over unauthenticated media, should raise 404 not found + channel2 = self.make_request( + "GET", + f"_matrix/media/v3/download/test/{unrestricted_media_id}", + shorthand=False, + ) + self.assertEqual(channel2.code, 404, channel2) + + # check thumbnails as well + params = "?width=32&height=32&method=crop" + channel3 = self.make_request( + "GET", + f"/_matrix/client/v1/media/thumbnail/test/{unrestricted_media_id}{params}", + shorthand=False, + access_token=self.tok, + ) + self.assertEqual(channel3.code, 200, channel3) + + params = "?width=32&height=32&method=crop" + channel4 = self.make_request( + "GET", + f"/_matrix/media/r0/thumbnail/test/{unrestricted_media_id}{params}", + shorthand=False, + access_token=self.tok, + ) + self.assertEqual(channel4.code, 404, channel4) + + def test_unauthed_unrestricted_local_media(self) -> None: + """ + Test that unauthenticated but(somehow) restricted media is available over the old + unauthenticated endpoints + """ + # First test round, restricted media + unrestricted_media_id = self.inject_local_media_and_send_event( + authed=False, 
restricted=False + ) + # request media over authenticated endpoint, should be found + channel1 = self.make_request( + "GET", + f"_matrix/client/v1/media/download/test/{unrestricted_media_id}", + access_token=self.tok, + shorthand=False, + ) + self.assertEqual(channel1.code, 200, channel1) + + # request same media over unauthenticated media, should be found + channel2 = self.make_request( + "GET", + f"_matrix/media/v3/download/test/{unrestricted_media_id}", + shorthand=False, + ) + self.assertEqual(channel2.code, 200, channel2) + + # check thumbnails as well + params = "?width=32&height=32&method=crop" + channel3 = self.make_request( + "GET", + f"/_matrix/client/v1/media/thumbnail/test/{unrestricted_media_id}{params}", + shorthand=False, + access_token=self.tok, + ) + self.assertEqual(channel3.code, 200, channel3) + + params = "?width=32&height=32&method=crop" + channel4 = self.make_request( + "GET", + f"/_matrix/media/r0/thumbnail/test/{unrestricted_media_id}{params}", + shorthand=False, + access_token=self.tok, + ) + self.assertEqual(channel4.code, 200, channel4) + + def test_restricted_remote_media(self) -> None: + restricted_media_id = self.inject_remote_media(restricted=True) + channel1 = self.make_request( + "GET", + f"_matrix/client/v1/media/download/{self.other_server_name}/{restricted_media_id}", + access_token=self.tok, + shorthand=False, + ) + self.assertEqual(channel1.code, 200, channel1) + + channel2 = self.make_request( + "GET", + f"_matrix/media/v3/download/{self.other_server_name}/{restricted_media_id}", + shorthand=False, + ) + self.assertEqual(channel2.code, 404, channel2) + + params = "?width=32&height=32&method=crop" + channel3 = self.make_request( + "GET", + f"/_matrix/client/v1/media/thumbnail/{self.other_server_name}/{restricted_media_id}{params}", + shorthand=False, + access_token=self.tok, + ) + self.assertEqual(channel3.code, 200, channel3) + + channel4 = self.make_request( + "GET", + 
f"/_matrix/media/r0/thumbnail/{self.other_server_name}/{restricted_media_id}{params}", + shorthand=False, + access_token=self.tok, + ) + self.assertEqual(channel4.code, 404, channel4) + + def test_authenticated_media_etag(self) -> None: + """ + Test that ETag works correctly with authenticated media over client + APIs. This is largely copied from AuthenticatedMediaTestCase above, except to + adjust to MSC3911 endpoints and to use this TestCase's helpers + """ + # upload some local media with authentication on + if self.enable_restricted_media: + media_id = self.upload_restricted_media() + else: + media_id = self.upload_unrestricted_media() + + # Check standard media endpoint + self._check_caching(f"/download/test/{media_id}") + + # check thumbnails as well + params = "?width=32&height=32&method=crop" + self._check_caching(f"/thumbnail/test/{media_id}{params}") + + # Remote media too? + remote_media_id = self.inject_remote_media(restricted=True) + self._check_caching(f"/download/{self.other_server_name}/{remote_media_id}") + + params = "?width=32&height=32&method=crop" + self._check_caching( + f"/thumbnail/{self.other_server_name}/{remote_media_id}{params}" + ) + + def _check_caching(self, path: str) -> None: + """ + Checks that: + 1. fetching the path returns an ETag header + 2. refetching with the ETag returns a 304 without a body + 3. refetching with the ETag but through unauthenticated endpoint + returns 404 + """ + + # Request media over authenticated endpoint, should be found + channel1 = self.make_request( + "GET", + f"/_matrix/client/v1/media{path}", + access_token=self.tok, + shorthand=False, + ) + self.assertEqual(channel1.code, 200) + + # Should have a single ETag field + etags = channel1.headers.getRawHeaders("ETag") + self.assertIsNotNone(etags) + assert etags is not None # For mypy + self.assertEqual(len(etags), 1) + etag = etags[0] + + # Refetching with the etag should result in 304 and empty body. 
+ channel2 = self.make_request( + "GET", + f"/_matrix/client/v1/media{path}", + access_token=self.tok, + shorthand=False, + custom_headers=[("If-None-Match", etag)], + ) + self.assertEqual(channel2.code, 304) + self.assertEqual(channel2.is_finished(), True) + self.assertNotIn("body", channel2.result) + + # Refetching with the etag but no access token should result in 404. + channel3 = self.make_request( + "GET", + f"/_matrix/media/r0{path}", + shorthand=False, + custom_headers=[("If-None-Match", etag)], + ) + self.assertEqual(channel3.code, 404) From 0f192bdf4b28d091b1da0fbdfa21b52e2e08f19b Mon Sep 17 00:00:00 2001 From: Soyoung Kim Date: Fri, 12 Sep 2025 14:00:21 +0200 Subject: [PATCH 23/35] chore: Abstract media handling functions to allow using them without having to be a media worker. This created an http replication endpoint on the media worker, and it will now need a replication listener declared in its configuration file. Make sure to add the worker to the `instance_map`, and it is going to be best practice going forward to use a new list `media_repo_instances` to name each worker that can have that replication listener. 
--- synapse/config/homeserver.py | 2 +- synapse/config/repository.py | 43 +- synapse/config/workers.py | 10 + .../federation/transport/server/federation.py | 3 + synapse/media/media_repository.py | 708 +++++++++++------- synapse/replication/http/__init__.py | 2 + synapse/replication/http/media.py | 100 +++ synapse/rest/client/media.py | 24 +- synapse/rest/media/create_resource.py | 5 +- .../rest/media/media_repository_resource.py | 2 + synapse/server.py | 12 +- tests/media/test_media_storage.py | 4 +- tests/media/test_url_previewer.py | 2 + tests/replication/test_multi_media_repo.py | 508 ++++++++++++- tests/rest/client/test_media.py | 95 ++- tests/rest/media/test_domain_blocking.py | 6 +- tests/rest/media/test_url_preview.py | 6 + 17 files changed, 1223 insertions(+), 309 deletions(-) create mode 100644 synapse/replication/http/media.py diff --git a/synapse/config/homeserver.py b/synapse/config/homeserver.py index 0b2413a83b..257fda084d 100644 --- a/synapse/config/homeserver.py +++ b/synapse/config/homeserver.py @@ -75,6 +75,7 @@ class HomeServerConfig(RootConfig): DatabaseConfig, LoggingConfig, RatelimitConfig, + WorkerConfig, ContentRepositoryConfig, OembedConfig, CaptchaConfig, @@ -103,7 +104,6 @@ class HomeServerConfig(RootConfig): RoomDirectoryConfig, ThirdPartyRulesConfig, TracerConfig, - WorkerConfig, RedisConfig, ExperimentalConfig, BackgroundUpdateConfig, diff --git a/synapse/config/repository.py b/synapse/config/repository.py index e6a5064c16..6acf634263 100644 --- a/synapse/config/repository.py +++ b/synapse/config/repository.py @@ -134,19 +134,46 @@ class ContentRepositoryConfig(Config): def read_config(self, config: JsonDict, **kwargs: Any) -> None: # Only enable the media repo if either the media repo is enabled or the # current worker app is the media repo. 
- if ( - config.get("enable_media_repo", True) is False - and config.get("worker_app") != "synapse.app.media_repository" - ): - self.can_load_media_repo = False - return + + workers_doing_media_duty = self.root.worker.workers_doing_media_duty + # It is expected by many places in Synapse and its unit tests that either there + # are no media repos, or only one, or every worker is one. If we have our new + # proper list, use that to decide if we are supposed to be handling these duties + if not workers_doing_media_duty: + if ( + config.get("enable_media_repo", True) is False + and config.get("worker_app") != "synapse.app.media_repository" + ): + self.can_load_media_repo = False + return + self.can_load_media_repo = True + else: + if self.root.worker.instance_name not in workers_doing_media_duty: + self.can_load_media_repo = False + return self.can_load_media_repo = True # Whether this instance should be the one to run the background jobs to # e.g clean up old URL previews. - self.media_instance_running_background_jobs = config.get( - "media_instance_running_background_jobs", + # We prefer the first worker that is on the list to be responsible. However, + # backwards compatible the old setting and allow it to override the list. + # + # The URLPreviewer is the only thing that cares about this. Refactoring this to + # not stomp all over it's own feet doing the same work twice(maybe a mod on + # sha256?) and then it would not matter which media worker does the job. 
+ media_instance_running_background_jobs_from_list = None + if self.root.worker.workers_doing_media_duty: + media_instance_running_background_jobs_from_list = ( + self.root.worker.workers_doing_media_duty[0] + ) + media_instance_running_background_jobs = config.get( + "media_instance_running_background_jobs" + ) + + self.media_instance_running_background_jobs = ( + media_instance_running_background_jobs + or media_instance_running_background_jobs_from_list ) self.max_upload_size = self.parse_size(config.get("max_upload_size", "50M")) diff --git a/synapse/config/workers.py b/synapse/config/workers.py index 825ba78482..3152e48c1c 100644 --- a/synapse/config/workers.py +++ b/synapse/config/workers.py @@ -479,6 +479,16 @@ def read_config( self.instance_map[instance] ) + self.workers_doing_media_duty = config.get("media_repo_instances", []) + # I would rather do this bit below, but the behavior of Synapse is rather lax. + # Documented what I mean in config/repository.py + # self.workers_doing_media_duty = self._worker_names_performing_this_duty( + # config, + # "enable_media_repo", + # "synapse.app.media_repository", + # "media_repo_instances", + # ) + def _should_this_worker_perform_duty( self, config: Dict[str, Any], diff --git a/synapse/federation/transport/server/federation.py b/synapse/federation/transport/server/federation.py index 3ebe2756ea..3bd6a9acab 100644 --- a/synapse/federation/transport/server/federation.py +++ b/synapse/federation/transport/server/federation.py @@ -51,6 +51,7 @@ ) from synapse.http.site import SynapseRequest from synapse.media._base import DEFAULT_MAX_TIMEOUT_MS, MAXIMUM_ALLOWED_MAX_TIMEOUT_MS +from synapse.media.media_repository import MediaRepository from synapse.media.thumbnailer import ThumbnailProvider from synapse.types import JsonDict from synapse.util import SYNAPSE_VERSION @@ -851,6 +852,7 @@ def __init__( super().__init__(hs, authenticator, ratelimiter, server_name) self.media_repo = self.hs.get_media_repository() 
self.dynamic_thumbnails = hs.config.media.dynamic_thumbnails + assert isinstance(self.media_repo, MediaRepository) self.thumbnail_provider = ThumbnailProvider( hs, self.media_repo, self.media_repo.media_storage ) @@ -880,6 +882,7 @@ async def on_GET( await self.thumbnail_provider.respond_local_thumbnail( request, media_id, width, height, method, m_type, max_timeout_ms, True ) + assert isinstance(self.media_repo, MediaRepository) self.media_repo.mark_recently_accessed(None, media_id) diff --git a/synapse/media/media_repository.py b/synapse/media/media_repository.py index 5020df60c1..5eb583d018 100644 --- a/synapse/media/media_repository.py +++ b/synapse/media/media_repository.py @@ -70,6 +70,7 @@ from synapse.media.thumbnailer import Thumbnailer, ThumbnailError from synapse.media.url_previewer import UrlPreviewer from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.replication.http.media import ReplicationCopyMediaServlet from synapse.storage.databases.main.media_repository import ( LocalMedia, MediaRestrictions, @@ -102,7 +103,7 @@ MEDIA_RETENTION_CHECK_PERIOD_MS = 60 * 60 * 1000 # 1 hour -class MediaRepository: +class AbstractMediaRepository: def __init__(self, hs: "HomeServer"): self.hs = hs self.auth = hs.get_auth() @@ -110,6 +111,412 @@ def __init__(self, hs: "HomeServer"): self.clock = hs.get_clock() self.server_name = hs.hostname self.store = hs.get_datastores().main + self._is_mine_server_name = hs.is_mine_server_name + self.enable_media_restriction = self.hs.config.experimental.msc3911_enabled + + @trace + async def create_media_id_without_expiration( + self, auth_user: UserID, restricted: bool = False + ) -> MXCUri: + """Create and store a media ID for a local user and return the MXC URI and its + expiration. + Args: + auth_user: The user_id of the uploader + restricted: If this is to be considered restricted media + Returns: + A MXC URI of the stored content. 
+ """ + media_id = random_string(24) + now = self.clock.time_msec() + await self.store.store_local_media_id( + media_id=media_id, + time_now_ms=now, + user_id=auth_user, + restricted=restricted, + ) + return MXCUri.from_str(f"mxc://{self.server_name}/{media_id}") + + async def get_media_info(self, mxc_uri: MXCUri) -> Union[LocalMedia, RemoteMedia]: + """Get information about a media item. + Args: + mxc_uri: The MXC URI of the media item. + Returns: + The media information, or None if not found. + """ + server_name = mxc_uri.server_name + media_id = mxc_uri.media_id + media_info: Optional[Union[LocalMedia, RemoteMedia]] + if self._is_mine_server_name(server_name): + media_info = await self.store.get_local_media(media_id) + else: + media_info = await self.store.get_cached_remote_media( + mxc_uri.server_name, mxc_uri.media_id + ) + if not media_info: + raise SynapseError(404, "Media not found", errcode="M_NOT_FOUND") + if media_info.quarantined_by: + raise SynapseError(404, "Media not found", errcode="M_NOT_FOUND") + return media_info + + async def create_or_update_content( + self, + media_type: str, + upload_name: Optional[str], + content: IO, + content_length: int, + auth_user: UserID, + media_id: Optional[str] = None, + restricted: bool = False, + ) -> MXCUri: + raise NotImplementedError( + "Sorry Mario, your MediaRepository related function is in another castle" + ) + + async def copy_media( + self, existing_mxc: MXCUri, auth_user: UserID, max_timeout_ms: int + ) -> MXCUri: + raise NotImplementedError( + "Sorry Mario, your MediaRepository related function is in another castle" + ) + + async def reached_pending_media_limit(self, auth_user: UserID) -> Tuple[bool, int]: + raise NotImplementedError( + "Sorry Mario, your MediaRepository related function is in another castle" + ) + + @trace + async def _generate_thumbnails( + self, + server_name: Optional[str], + media_id: str, + file_id: str, + media_type: str, + url_cache: bool = False, + ) -> Optional[dict]: + raise 
NotImplementedError( + "Sorry Mario, your MediaRepository related function is in another castle" + ) + + async def create_media_id( + self, auth_user: UserID, restricted: bool = False + ) -> Tuple[str, int]: + raise NotImplementedError( + "Sorry Mario, your MediaRepository related function is in another castle" + ) + + async def delete_old_remote_media(self, before_ts: int) -> Dict[str, int]: + raise NotImplementedError( + "Sorry Mario, your MediaRepository related function is in another castle" + ) + + async def delete_local_media_ids( + self, media_ids: List[str] + ) -> Tuple[List[str], int]: + raise NotImplementedError( + "Sorry Mario, your MediaRepository related function is in another castle" + ) + + async def delete_old_local_media( + self, + before_ts: int, + size_gt: int = 0, + keep_profiles: bool = True, + delete_quarantined_media: bool = False, + delete_protected_media: bool = False, + ) -> Tuple[List[str], int]: + raise NotImplementedError( + "Sorry Mario, your MediaRepository related function is in another castle" + ) + + async def get_local_media( + self, + request: SynapseRequest, + media_id: str, + name: Optional[str], + max_timeout_ms: int, + requester: Optional[Requester] = None, + allow_authenticated: bool = True, + federation: bool = False, + ) -> None: + raise NotImplementedError( + "Sorry Mario, your MediaRepository related function is in another castle" + ) + + async def validate_media_restriction( + self, + request: SynapseRequest, + media_info: Optional[LocalMedia], + media_id: Optional[str], + is_federation: bool = False, + ) -> Optional[MediaRestrictions]: + """ + MSC3911: If media is restricted but restriction is empty, the media is in + pending state and only creator can see it until it is attached to an event. If + there is a restriction return MediaRestrictions after validation. + + Args: + request: The incoming request. + media_info: Optional, the media information. + media_id: Optional, the media ID to validate. 
+ + Returns: + MediaRestrictions if there is one set, otherwise raise SynapseError. + """ + if not media_info and media_id: + media_info = await self.store.get_local_media(media_id) + if not media_info: + return None + restricted = media_info.restricted + if not restricted: + return None + attachments: Optional[MediaRestrictions] = media_info.attachments + # for both federation and client endpoints + if attachments: + # Only one of event_id or profile_user_id must be set, not both, not neither + if attachments.event_id is None and attachments.profile_user_id is None: + raise SynapseError( + HTTPStatus.FORBIDDEN, + "MediaRestrictions must have exactly one of event_id or profile_user_id set.", + errcode=Codes.FORBIDDEN, + ) + if bool(attachments.event_id) == bool(attachments.profile_user_id): + raise SynapseError( + HTTPStatus.FORBIDDEN, + "MediaRestrictions must have exactly one of event_id or profile_user_id set.", + errcode=Codes.FORBIDDEN, + ) + + if not attachments and is_federation: + raise SynapseError( + HTTPStatus.NOT_FOUND, + "Not found '%s'" % (request.path.decode(),), + errcode=Codes.NOT_FOUND, + ) + + if not attachments and not is_federation: + if ( + isinstance(request.requester, Requester) + and request.requester.user.to_string() != media_info.user_id + ): + raise SynapseError( + HTTPStatus.NOT_FOUND, + "Not found '%s'" % (request.path.decode(),), + errcode=Codes.NOT_FOUND, + ) + else: + return None + return attachments + + async def is_media_visible( + self, requesting_user: UserID, media_info_object: Union[LocalMedia, RemoteMedia] + ) -> None: + """ + Verify that media requested for download should be visible to the user making + the request + """ + + if not self.enable_media_restriction: + return + + if not media_info_object.restricted: + return + + if not media_info_object.attachments: + # When the media has not been attached yet, only the originating user can + # see it. 
But once attachments have been formed, standard other rules apply + if isinstance(media_info_object, LocalMedia) and ( + requesting_user.to_string() == str(media_info_object.user_id) + ): + return + + # It was restricted, but no attachments. Deny + raise UnauthorizedRequestAPICallError( + f"Media requested ('{media_info_object.media_id}') is restricted" + ) + + attached_event_id = media_info_object.attachments.event_id + attached_profile_user_id = media_info_object.attachments.profile_user_id + + if attached_event_id: + event_base = await self.store.get_event(attached_event_id) + storage_controllers = self.hs.get_storage_controllers() + if event_base.is_state(): + # The standard event visibility utility, filter_events_for_client(), + # does not seem to meet the needs of a good UX when restricting and + # allowing media. This is a very, very simple version to be used for + # state events. + + # First we will collect the current membership of the user for the room + # the relevant event came from. Then we will collect the membership and + # m.room.history_visibility event at the time of the relevant event. 
+ + # Since it is hard to find a relevant place in which to search back in + # time to find out if a given room ever had anything other than a leave + # event, this is the simplest without having to do tablescans + + # Need membership of NOW + ( + membership_now, + _, + ) = await self.store.get_local_current_membership_for_user_in_room( + requesting_user.to_string(), event_base.room_id + ) + + if not membership_now: + membership_now = Membership.LEAVE + + membership_state_key = (EventTypes.Member, requesting_user.to_string()) + types = (_HISTORY_VIS_KEY, membership_state_key) + # and history visibility and membership of THEN + event_id_to_state = ( + await storage_controllers.state.get_state_for_events( + [attached_event_id], + state_filter=StateFilter.from_types(types), + ) + ) + + state_map = event_id_to_state.get(attached_event_id) + # Do we need to guard against not having state of a room? + assert state_map is not None + + visibility = get_effective_room_visibility_from_state(state_map) + + memb_then_evt = state_map.get(membership_state_key) + membership_then = Membership.LEAVE + if memb_then_evt: + membership_then = memb_then_evt.content.get( + "membership", Membership.LEAVE + ) + + # Have a few numbers ready for comparison below. These resolve to int + # The index of the visibility present from the event + visibility_priority = VISIBILITY_PRIORITY.index(visibility) + membership_priority_now = MEMBERSHIP_PRIORITY.index(membership_now) + membership_priority_then = MEMBERSHIP_PRIORITY.index(membership_then) + + # These are essentially constants, in that they should not change + world_readable_index = VISIBILITY_PRIORITY.index( + HistoryVisibility.WORLD_READABLE + ) + shared_visibility_index = VISIBILITY_PRIORITY.index( + HistoryVisibility.SHARED + ) + mem_leave_index = MEMBERSHIP_PRIORITY.index(Membership.LEAVE) + + # I disagree with this. 
'Shared' by spec implies that some sort of + # positive membership event took place, but the stock + # filter_events_for_client() seems to treat SHARED like WORLD_READABLE, + # so at least this matches + if visibility_priority in [ + world_readable_index, + shared_visibility_index, + ]: + # world readable should always be seen + return + + # If the room is invite visible, and the user is invited, move on + if visibility_priority == VISIBILITY_PRIORITY.index( + HistoryVisibility.INVITED + ) and membership_priority_now == MEMBERSHIP_PRIORITY.index( + Membership.INVITE + ): + return + + # The visibility of the room is shared or greater, so requires at + # the minimum a 'knock' level. Make sure the membership of the user + # is better than leave + if ( + visibility_priority >= shared_visibility_index + and membership_priority_now < mem_leave_index + ): + return + + # Cover the case that a user has left a room but still should see any + # media they were allowed to see prior + # The visibility of the room is shared or greater, so requires at + # the minimum a 'knock' level. Make sure the membership of the user + # is better than leave + if ( + visibility_priority >= shared_visibility_index + and membership_priority_then < mem_leave_index + ): + return + + else: + filtered_events = await filter_events_for_client( + storage_controllers, + requesting_user.to_string(), + [event_base], + ) + if len(filtered_events) > 0: + return + + elif attached_profile_user_id: + # Can this user see that profile? + + # The error returns here may not be suitable, use the work around below + # If shared room restricted profile lookups, it will be restricted + # to users that share rooms + # await self.profile_handler.check_profile_query_allowed( + # restrictions.profile_user_id, requester.user + # ) + # return + + if self.hs.config.server.limit_profile_requests_to_users_who_share_rooms: + # First take care of the case where the requesting user IS the creating + # user. 
The other function below does not handle this. + if requesting_user.to_string() == attached_profile_user_id.to_string(): + return + + # This call returns a set() that contains which of the "other_user_ids" + # share a room. Since we give it only one, if bool(set()) is True, then they + # share some room or had at least one invite between them. + if not await self.store.do_users_share_a_room_joined_or_invited( + requesting_user.to_string(), + [attached_profile_user_id.to_string()], + ): + raise UnauthorizedRequestAPICallError( + f"Media requested ('{media_info_object.media_id}') is restricted" + ) + + # check these settings: + # * allow_profile_lookup_over_federation + + # If 'limit_profile_requests_to_users_who_share_rooms' is not enabled, all + # bets are kinda off + return + + # It was a third unknown restriction, or otherwise did not pass inspection + raise UnauthorizedRequestAPICallError( + f"Media requested ('{media_info_object.media_id}') is restricted" + ) + + +class MediaRepositoryWorker(AbstractMediaRepository): + def __init__(self, hs: "HomeServer"): + super().__init__(hs) + # initialize replication endpoint here + self.copy_media_client = ReplicationCopyMediaServlet.make_client(hs) + + async def copy_media( + self, existing_mxc: MXCUri, auth_user: UserID, max_timeout_ms: int + ) -> MXCUri: + """ + Call out to the worker responsible for handling media to copy this media object + """ + result = await self.copy_media_client( + instance_name=self.hs.config.worker.workers_doing_media_duty[0], + server_name=existing_mxc.server_name, + media_id=existing_mxc.media_id, + user_id=auth_user.to_string(), + max_timeout_ms=max_timeout_ms, + ) + return MXCUri.from_str(result["content_uri"]) + + +class MediaRepository(AbstractMediaRepository): + def __init__(self, hs: "HomeServer"): + super().__init__(hs) self.max_upload_size = hs.config.media.max_upload_size self.max_image_pixels = hs.config.media.max_image_pixels self.unused_expiration_time = 
hs.config.media.unused_expiration_time @@ -199,8 +606,6 @@ def __init__(self, hs: "HomeServer"): key=lambda limit: limit.time_period_ms, reverse=True ) - self.enable_media_restriction = self.hs.config.experimental.msc3911_enabled - def _start_update_recently_accessed(self) -> Deferred: return run_as_background_process( "update_recently_accessed_media", self._update_recently_accessed @@ -414,7 +819,53 @@ async def create_or_update_content( except Exception as e: logger.info("Failed to generate thumbnails: %s", e) - return MXCUri(self.server_name, media_id) + return MXCUri(self.server_name, media_id) + + async def copy_media( + self, existing_mxc: MXCUri, auth_user: UserID, max_timeout_ms: int + ) -> MXCUri: + """ + Copy an existing piece of media into a new file with new LocalMedia + + Args: + existing_mxc: The existing media information + auth_user: The UserID of the user making the request + max_timeout_ms: The millisecond timeout for retrieving existing media info + """ + + old_media_info = await self.get_media_info(existing_mxc) + if isinstance(old_media_info, RemoteMedia): + file_info = FileInfo( + server_name=old_media_info.media_origin, file_id=old_media_info.media_id + ) + else: + file_info = FileInfo(server_name=None, file_id=old_media_info.media_id) + + # This will ensure that if there is another storage provider containing our old + # media, it will be in our local cache before the copy takes place. + # Conveniently, it also gives us the local path of where the file lives. + local_path = await self.media_storage.ensure_media_is_in_local_cache(file_info) + + assert old_media_info.media_length is not None + + # It may end up being that this needs to be pushed down into the MediaStorage + # class. It needs abstraction badly, but that is beyond me at the moment. + io_object = open(local_path, "rb") + + # Let existing methods handle creating the new file for us. By not passing a + # media id, one will be created. 
+ new_mxc_uri = await self.create_or_update_content( + media_type=old_media_info.media_type, + upload_name=old_media_info.upload_name, + content=io_object, + content_length=old_media_info.media_length, + auth_user=auth_user, + restricted=True, + ) + # I could not find a place this was close()'d explicitly, but this felt prudent + io_object.close() + + return new_mxc_uri def respond_not_yet_uploaded(self, request: SynapseRequest) -> None: respond_with_json( @@ -480,190 +931,6 @@ async def get_local_media_info( self.respond_not_yet_uploaded(request) return None - async def is_media_visible( - self, requesting_user: UserID, media_info_object: Union[LocalMedia, RemoteMedia] - ) -> None: - """ - Verify that media requested for download should be visible to the user making - the request - """ - - if not self.enable_media_restriction: - return - - if not media_info_object.restricted: - return - - if not media_info_object.attachments: - # When the media has not been attached yet, only the originating user can - # see it. But once attachments have been formed, standard other rules apply - if isinstance(media_info_object, LocalMedia) and ( - requesting_user.to_string() == str(media_info_object.user_id) - ): - return - - # It was restricted, but no attachments. Deny - raise UnauthorizedRequestAPICallError( - f"Media requested ('{media_info_object.media_id}') is restricted" - ) - - attached_event_id = media_info_object.attachments.event_id - attached_profile_user_id = media_info_object.attachments.profile_user_id - - if attached_event_id: - event_base = await self.store.get_event(attached_event_id) - storage_controllers = self.hs.get_storage_controllers() - if event_base.is_state(): - # The standard event visibility utility, filter_events_for_client(), - # does not seem to meet the needs of a good UX when restricting and - # allowing media. This is a very, very simple version to be used for - # state events. 
- - # First we will collect the current membership of the user for the room - # the relevant event came from. Then we will collect the membership and - # m.room.history_visibility event at the time of the relevant event. - - # Since it is hard to find a relevant place in which to search back in - # time to find out if a given room ever had anything other than a leave - # event, this is the simplest without having to do tablescans - - # Need membership of NOW - ( - membership_now, - _, - ) = await self.store.get_local_current_membership_for_user_in_room( - requesting_user.to_string(), event_base.room_id - ) - - if not membership_now: - membership_now = Membership.LEAVE - - membership_state_key = (EventTypes.Member, requesting_user.to_string()) - types = (_HISTORY_VIS_KEY, membership_state_key) - # and history visibility and membership of THEN - event_id_to_state = ( - await storage_controllers.state.get_state_for_events( - [attached_event_id], - state_filter=StateFilter.from_types(types), - ) - ) - - state_map = event_id_to_state.get(attached_event_id) - # Do we need to guard against not having state of a room? - assert state_map is not None - - visibility = get_effective_room_visibility_from_state(state_map) - - memb_then_evt = state_map.get(membership_state_key) - membership_then = Membership.LEAVE - if memb_then_evt: - membership_then = memb_then_evt.content.get( - "membership", Membership.LEAVE - ) - - # Have a few numbers ready for comparison below. 
These resolve to int - # The index of the visibility present from the event - visibility_priority = VISIBILITY_PRIORITY.index(visibility) - membership_priority_now = MEMBERSHIP_PRIORITY.index(membership_now) - membership_priority_then = MEMBERSHIP_PRIORITY.index(membership_then) - - # These are essentially constants, in that they should not change - world_readable_index = VISIBILITY_PRIORITY.index( - HistoryVisibility.WORLD_READABLE - ) - shared_visibility_index = VISIBILITY_PRIORITY.index( - HistoryVisibility.SHARED - ) - mem_leave_index = MEMBERSHIP_PRIORITY.index(Membership.LEAVE) - - # I disagree with this. 'Shared' by spec implies that some sort of - # positive membership event took place, but the stock - # filter_events_for_client() seems to treat SHARED like WORLD_READABLE, - # so at least this matches - if visibility_priority in [ - world_readable_index, - shared_visibility_index, - ]: - # world readable should always be seen - return - - # If the room is invite visible, and the user is invited, move on - if visibility_priority == VISIBILITY_PRIORITY.index( - HistoryVisibility.INVITED - ) and membership_priority_now == MEMBERSHIP_PRIORITY.index( - Membership.INVITE - ): - return - - # The visibility of the room is shared or greater, so requires at - # the minimum a 'knock' level. Make sure the membership of the user - # is better than leave - if ( - visibility_priority >= shared_visibility_index - and membership_priority_now < mem_leave_index - ): - return - - # Cover the case that a user has left a room but still should see any - # media they were allowed to see prior - # The visibility of the room is shared or greater, so requires at - # the minimum a 'knock' level. 
Make sure the membership of the user - # is better than leave - if ( - visibility_priority >= shared_visibility_index - and membership_priority_then < mem_leave_index - ): - return - - else: - filtered_events = await filter_events_for_client( - storage_controllers, - requesting_user.to_string(), - [event_base], - ) - if len(filtered_events) > 0: - return - - elif attached_profile_user_id: - # Can this user see that profile? - - # The error returns here may not be suitable, use the work around below - # If shared room restricted profile lookups, it will be restricted - # to users that share rooms - # await self.profile_handler.check_profile_query_allowed( - # restrictions.profile_user_id, requester.user - # ) - # return - - if self.hs.config.server.limit_profile_requests_to_users_who_share_rooms: - # First take care of the case where the requesting user IS the creating - # user. The other function below does not handle this. - if requesting_user.to_string() == attached_profile_user_id.to_string(): - return - - # This call returns a set() that contains which of the "other_user_ids" - # share a room. Since we give it only one, if bool(set()) is True, then they - # share some room or had at least one invite between them. 
- if not await self.store.do_users_share_a_room_joined_or_invited( - requesting_user.to_string(), - [attached_profile_user_id.to_string()], - ): - raise UnauthorizedRequestAPICallError( - f"Media requested ('{media_info_object.media_id}') is restricted" - ) - - # check these settings: - # * allow_profile_lookup_over_federation - - # If 'limit_profile_requests_to_users_who_share_rooms' is not enabled, all - # bets are kinda off - return - - # It was a third unknown restriction, or otherwise did not pass inspection - raise UnauthorizedRequestAPICallError( - f"Media requested ('{media_info_object.media_id}') is restricted" - ) - async def get_local_media( self, request: SynapseRequest, @@ -1844,68 +2111,3 @@ async def _remove_local_media_from_disk( removed_media.append(media_id) return removed_media, len(removed_media) - - async def validate_media_restriction( - self, - request: SynapseRequest, - media_info: Optional[LocalMedia], - media_id: Optional[str], - is_federation: bool = False, - ) -> Optional[MediaRestrictions]: - """ - MSC3911: If media is restricted but restriction is empty, the media is in - pending state and only creator can see it until it is attached to an event. If - there is a restriction return MediaRestrictions after validation. - - Args: - request: The incoming request. - media_info: Optional, the media information. - media_id: Optional, the media ID to validate. - - Returns: - MediaRestrictions if there is one set, otherwise raise SynapseError. 
- """ - if not media_info and media_id: - media_info = await self.store.get_local_media(media_id) - if not media_info: - return None - restricted = media_info.restricted - if not restricted: - return None - attachments: Optional[MediaRestrictions] = media_info.attachments - # for both federation and client endpoints - if attachments: - # Only one of event_id or profile_user_id must be set, not both, not neither - if attachments.event_id is None and attachments.profile_user_id is None: - raise SynapseError( - HTTPStatus.FORBIDDEN, - "MediaRestrictions must have exactly one of event_id or profile_user_id set.", - errcode=Codes.FORBIDDEN, - ) - if bool(attachments.event_id) == bool(attachments.profile_user_id): - raise SynapseError( - HTTPStatus.FORBIDDEN, - "MediaRestrictions must have exactly one of event_id or profile_user_id set.", - errcode=Codes.FORBIDDEN, - ) - - if not attachments and is_federation: - raise SynapseError( - HTTPStatus.NOT_FOUND, - "Not found '%s'" % (request.path.decode(),), - errcode=Codes.NOT_FOUND, - ) - - if not attachments and not is_federation: - if ( - isinstance(request.requester, Requester) - and request.requester.user.to_string() != media_info.user_id - ): - raise SynapseError( - HTTPStatus.NOT_FOUND, - "Not found '%s'" % (request.path.decode(),), - errcode=Codes.NOT_FOUND, - ) - else: - return None - return attachments diff --git a/synapse/replication/http/__init__.py b/synapse/replication/http/__init__.py index ab2e6707cd..4e2473dcfc 100644 --- a/synapse/replication/http/__init__.py +++ b/synapse/replication/http/__init__.py @@ -28,6 +28,7 @@ devices, federation, login, + media, membership, presence, push, @@ -61,6 +62,7 @@ def register_servlets(self, hs: "HomeServer") -> None: push.register_servlets(hs, self) state.register_servlets(hs, self) devices.register_servlets(hs, self) + media.register_servlets(hs, self) # The following can't currently be instantiated on workers. 
if hs.config.worker.worker_app is None: diff --git a/synapse/replication/http/media.py b/synapse/replication/http/media.py new file mode 100644 index 0000000000..42ecfcd9d4 --- /dev/null +++ b/synapse/replication/http/media.py @@ -0,0 +1,100 @@ +# +# This file is licensed under the Affero General Public License (AGPL) version 3. +# +# Copyright (C) 2023 New Vector, Ltd +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as +# published by the Free Software Foundation, either version 3 of the +# License, or (at your option) any later version. +# +# See the GNU Affero General Public License for more details: +# . +# +# Originally licensed under the Apache License, Version 2.0: +# . +# +# [This file includes modifications made by New Vector Limited] +# +# + +import logging +from http import HTTPStatus +from typing import TYPE_CHECKING, Tuple + +from matrix_common.types.mxc_uri import MXCUri + +from twisted.web.server import Request + +from synapse.api.errors import Codes, SynapseError +from synapse.http.server import HttpServer +from synapse.replication.http._base import ReplicationEndpoint +from synapse.types import JsonDict, UserID + +if TYPE_CHECKING: + from synapse.server import HomeServer + +logger = logging.getLogger(__name__) + + +class ReplicationCopyMediaServlet(ReplicationEndpoint): + """Request the MediaRepository to make a copy of a piece of media. 
+ + Request format: + + POST /_synapse/replication/copy_media/:server_name/:media_id + + { + "user_id": UserID.to_string(), + "max_timeout_ms": int of how long to wait + } + + """ + + NAME = "copy_media" + PATH_ARGS = ("server_name", "media_id") + + def __init__(self, hs: "HomeServer"): + super().__init__(hs) + self.media_repo = hs.get_media_repository() + + @staticmethod + async def _serialize_payload( # type: ignore[override] + server_name: str, + media_id: str, + user_id: str, + max_timeout_ms: int, + ) -> JsonDict: + """ + Args: + server_name: The server_name that originated the media. + media_id: The individualized media id for the origin media. + """ + return {"user_id": user_id, "max_timeout_ms": max_timeout_ms} + + async def _handle_request( # type: ignore[override] + self, + request: Request, + content: JsonDict, + server_name: str, + media_id: str, + ) -> Tuple[int, JsonDict]: + user_id = UserID.from_string(content["user_id"]) + max_timeout_ms = content["max_timeout_ms"] + try: + mxc_uri = MXCUri(server_name=server_name, media_id=media_id) + except ValueError: + # TODO: Make sure the codes here are proper + raise SynapseError( + HTTPStatus.BAD_REQUEST, + "MXC as provided was not formatted correctly", + Codes.INVALID_PARAM, + ) + mxc_uri = await self.media_repo.copy_media( + mxc_uri, user_id, max_timeout_ms=max_timeout_ms + ) + return 200, {"content_uri": str(mxc_uri)} + + +def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: + ReplicationCopyMediaServlet(hs).register(http_server) diff --git a/synapse/rest/client/media.py b/synapse/rest/client/media.py index ada6c34bc4..6a69993eb8 100644 --- a/synapse/rest/client/media.py +++ b/synapse/rest/client/media.py @@ -24,6 +24,8 @@ import re from typing import Optional, Union +from matrix_common.types.mxc_uri import MXCUri + from synapse.api.errors import ( Codes, NotFoundError, @@ -379,23 +381,16 @@ async def on_POST( if media_info: try: - mxc_uri, _ = await self.media_repo.create_media_id( - 
requester.user, restricted=True + mxc_uri = await self.media_repo.copy_media( + MXCUri(server_name=server_name, media_id=media_id), + requester.user, + max_timeout_ms=max_timeout_ms, ) - if media_info.media_length and media_info.sha256: - await self.store.update_local_media( - media_id=mxc_uri.split("/")[-1], - media_type=media_info.media_type, - upload_name=media_info.upload_name, - media_length=media_info.media_length, - user_id=requester.user, - sha256=media_info.sha256, - quarantined_by=None, - ) + respond_with_json( request, 200, - {"content_uri": mxc_uri}, + {"content_uri": str(mxc_uri)}, send_cors=True, ) except Exception as e: @@ -405,6 +400,9 @@ async def on_POST( def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: media_repo = hs.get_media_repository() + # None of these endpoints should be mounted on something that isn't a proper media + # worker, so this is safe to make mypy hush + assert isinstance(media_repo, MediaRepository) if hs.config.media.url_preview_enabled: PreviewURLServlet(hs, media_repo, media_repo.media_storage).register( http_server diff --git a/synapse/rest/media/create_resource.py b/synapse/rest/media/create_resource.py index e863310635..4c1c486583 100644 --- a/synapse/rest/media/create_resource.py +++ b/synapse/rest/media/create_resource.py @@ -38,7 +38,10 @@ class CreateResource(RestServlet): def __init__( - self, hs: "HomeServer", media_repo: "MediaRepository", restricted: bool = False + self, + hs: "HomeServer", + media_repo: "MediaRepository", + restricted: bool = False, ): super().__init__() diff --git a/synapse/rest/media/media_repository_resource.py b/synapse/rest/media/media_repository_resource.py index 963b9de252..36767faf20 100644 --- a/synapse/rest/media/media_repository_resource.py +++ b/synapse/rest/media/media_repository_resource.py @@ -23,6 +23,7 @@ from synapse.config._base import ConfigError from synapse.http.server import HttpServer, JsonResource +from synapse.media.media_repository import 
MediaRepository from .config_resource import MediaConfigResource from .create_resource import CreateResource @@ -96,6 +97,7 @@ def __init__(self, hs: "HomeServer"): @staticmethod def register_servlets(http_server: HttpServer, hs: "HomeServer") -> None: media_repo = hs.get_media_repository() + assert isinstance(media_repo, MediaRepository) # Note that many of these should not exist as v1 endpoints, but empirically # a lot of traffic still goes to them. diff --git a/synapse/server.py b/synapse/server.py index 231bd14907..888c9d554e 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -127,7 +127,11 @@ SimpleHttpClient, ) from synapse.http.matrixfederationclient import MatrixFederationHttpClient -from synapse.media.media_repository import MediaRepository +from synapse.media.media_repository import ( + AbstractMediaRepository, + MediaRepository, + MediaRepositoryWorker, +) from synapse.metrics import register_threadpool from synapse.metrics.common_usage_metrics import CommonUsageMetricsManager from synapse.module_api import ModuleApi @@ -700,8 +704,10 @@ def get_media_repository_resource(self) -> MediaRepositoryResource: return MediaRepositoryResource(self) @cache_in_self - def get_media_repository(self) -> MediaRepository: - return MediaRepository(self) + def get_media_repository(self) -> AbstractMediaRepository: + if self.config.media.can_load_media_repo: + return MediaRepository(self) + return MediaRepositoryWorker(self) @cache_in_self def get_federation_transport_client(self) -> TransportLayerClient: diff --git a/tests/media/test_media_storage.py b/tests/media/test_media_storage.py index 2f7cf4569b..6edf687938 100644 --- a/tests/media/test_media_storage.py +++ b/tests/media/test_media_storage.py @@ -47,6 +47,7 @@ from synapse.logging.context import make_deferred_yieldable from synapse.media._base import FileInfo, ThumbnailInfo from synapse.media.filepath import MediaFilePaths +from synapse.media.media_repository import MediaRepository from 
synapse.media.media_storage import MediaStorage, ReadableFileWrapper from synapse.media.storage_provider import FileStorageProviderBackend from synapse.media.thumbnailer import ThumbnailProvider @@ -629,7 +630,7 @@ def test_thumbnail_repeated_thumbnail(self) -> None: info = self.get_success(self.store.get_cached_remote_media(origin, media_id)) assert info is not None file_id = info.filesystem_id - + assert isinstance(self.media_repo, MediaRepository) thumbnail_dir = self.media_repo.filepaths.remote_media_thumbnail_dir( origin, file_id ) @@ -730,6 +731,7 @@ def test_same_quality(self, method: str, desired_size: int) -> None: content_type = self.test_image.content_type.decode() media_repo = self.hs.get_media_repository() + assert isinstance(media_repo, MediaRepository) thumbnail_provider = ThumbnailProvider( self.hs, media_repo, media_repo.media_storage ) diff --git a/tests/media/test_url_previewer.py b/tests/media/test_url_previewer.py index 0ae414d408..0e388c004d 100644 --- a/tests/media/test_url_previewer.py +++ b/tests/media/test_url_previewer.py @@ -22,6 +22,7 @@ from twisted.test.proto_helpers import MemoryReactor +from synapse.media.media_repository import MediaRepository from synapse.server import HomeServer from synapse.util import Clock @@ -69,6 +70,7 @@ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: media_repo = hs.get_media_repository() + assert isinstance(media_repo, MediaRepository) assert media_repo.url_previewer is not None self.url_previewer = media_repo.url_previewer diff --git a/tests/replication/test_multi_media_repo.py b/tests/replication/test_multi_media_repo.py index f36af877c4..45bb842520 100644 --- a/tests/replication/test_multi_media_repo.py +++ b/tests/replication/test_multi_media_repo.py @@ -18,18 +18,27 @@ # [This file includes modifications made by New Vector Limited] # # +import io import logging import os -from typing import 
Any, Optional, Tuple +import time +from http import HTTPStatus +from typing import Any, Dict, Optional, Tuple + +from matrix_common.types.mxc_uri import MXCUri from twisted.internet.protocol import Factory from twisted.test.proto_helpers import MemoryReactor from twisted.web.http import HTTPChannel from twisted.web.server import Request +from synapse.api.constants import EventTypes, HistoryVisibility +from synapse.media._base import FileInfo +from synapse.media.media_repository import MediaRepository from synapse.rest import admin -from synapse.rest.client import login, media +from synapse.rest.client import login, media, room from synapse.server import HomeServer +from synapse.types import UserID, create_requester from synapse.util import Clock from tests.http import ( @@ -39,7 +48,7 @@ ) from tests.replication._base import BaseMultiWorkerStreamTestCase from tests.server import FakeChannel, FakeTransport, make_request -from tests.test_utils import SMALL_PNG +from tests.test_utils import SMALL_PNG, SMALL_PNG_SHA256 from tests.unittest import override_config logger = logging.getLogger(__name__) @@ -246,16 +255,16 @@ def test_download_image_race(self) -> None: def _count_remote_media(self) -> int: """Count the number of files in our remote media directory.""" - path = os.path.join( - self.hs.get_media_repository().primary_base_path, "remote_content" - ) + media_repo = self.hs.get_media_repository() + assert isinstance(media_repo, MediaRepository) + path = os.path.join(media_repo.primary_base_path, "remote_content") return sum(len(files) for _, _, files in os.walk(path)) def _count_remote_thumbnails(self) -> int: """Count the number of files in our remote thumbnails directory.""" - path = os.path.join( - self.hs.get_media_repository().primary_base_path, "remote_thumbnail" - ) + media_repo = self.hs.get_media_repository() + assert isinstance(media_repo, MediaRepository) + path = os.path.join(media_repo.primary_base_path, "remote_thumbnail") return sum(len(files) for _, 
_, files in os.walk(path)) @@ -478,19 +487,488 @@ def test_download_image_race(self) -> None: def _count_remote_media(self) -> int: """Count the number of files in our remote media directory.""" - path = os.path.join( - self.hs.get_media_repository().primary_base_path, "remote_content" - ) + media_repo = self.hs.get_media_repository() + assert isinstance(media_repo, MediaRepository) + path = os.path.join(media_repo.primary_base_path, "remote_content") return sum(len(files) for _, _, files in os.walk(path)) def _count_remote_thumbnails(self) -> int: """Count the number of files in our remote thumbnails directory.""" - path = os.path.join( - self.hs.get_media_repository().primary_base_path, "remote_thumbnail" - ) + media_repo = self.hs.get_media_repository() + assert isinstance(media_repo, MediaRepository) + path = os.path.join(media_repo.primary_base_path, "remote_thumbnail") return sum(len(files) for _, _, files in os.walk(path)) +class CopyRestrictedResourceReplicationTestCase(BaseMultiWorkerStreamTestCase): + """ + Tests copy API when `msc3911_enabled` is configured to be True. 
+ """ + + servlets = [ + # media.register_servlets, + login.register_servlets, + admin.register_servlets, + room.register_servlets, + ] + + # def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: + # return self.setup_test_homeserver(config=config) + + def default_config(self) -> Dict[str, Any]: + config = super().default_config() + config.update( + { + "experimental_features": {"msc3911_enabled": True}, + "media_repo_instances": ["media_worker_1"], + } + ) + config["instance_map"] = { + "main": {"host": "testserv", "port": 8765}, + "media_worker_1": {"host": "testserv", "port": 1001}, + } + + return config + + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + # self.media_repo = hs.get_media_repository() + self.profile_handler = self.hs.get_profile_handler() + self.user = self.register_user("user", "testpass") + self.user_tok = self.login("user", "testpass") + self.other_user = self.register_user("other", "testpass") + self.other_user_tok = self.login("other", "testpass") + + def make_worker_hs( + self, worker_app: str, extra_config: Optional[dict] = None, **kwargs: Any + ) -> HomeServer: + worker_hs = super().make_worker_hs(worker_app, extra_config, **kwargs) + # Force the media paths onto the replication resource. + worker_hs.get_media_repository_resource().register_servlets( + self._hs_to_site[worker_hs].resource, worker_hs + ) + media.register_servlets(worker_hs, self._hs_to_site[worker_hs].resource) + return worker_hs + + def fetch_media( + self, + hs: HomeServer, + mxc_uri: MXCUri, + access_token: Optional[str] = None, + expected_code: int = 200, + ) -> FakeChannel: + """ + Test retrieving the media. 
We do not care about the content of the media, just + that the response is correct + """ + channel = make_request( + self.reactor, + self._hs_to_site[hs], + "GET", + f"/_matrix/client/v1/media/download/{mxc_uri.server_name}/{mxc_uri.media_id}", + access_token=access_token, + ) + assert channel.code == expected_code, channel.code + return channel + + def test_copy_local_restricted_resource(self) -> None: + """ + Tests that the new copy endpoint creates a new mxc uri for restricted resource. + """ + media_worker = self.make_worker_hs( + "synapse.app.generic_worker", {"worker_name": "media_worker_1"} + ) + media_repo = media_worker.get_media_repository() + + # Create a private room + room_id = self.helper.create_room_as( + self.user, + is_public=False, + tok=self.user_tok, + extra_content={ + "initial_state": [ + { + "type": EventTypes.RoomHistoryVisibility, + "state_key": "", + "content": {"history_visibility": HistoryVisibility.JOINED}, + }, + ] + }, + ) + # Invite the other user + self.helper.invite(room_id, self.user, self.other_user, tok=self.user_tok) + self.helper.join(room_id, self.other_user, tok=self.other_user_tok) + + # The media is created with user_tok + content = io.BytesIO(SMALL_PNG) + content_uri = self.get_success( + media_repo.create_or_update_content( + "image/png", + "test_png_upload", + content, + 67, + UserID.from_string(self.user), + restricted=True, + ) + ) + media_id = content_uri.media_id + + # User sends a message with media + channel = self.make_request( + "PUT", + f"/rooms/{room_id}/send/m.room.message/{str(time.time())}?org.matrix.msc3911.attach_media={str(content_uri)}", + content={"msgtype": "m.text", "body": "Hi, this is a message"}, + access_token=self.user_tok, + ) + self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body) + assert "event_id" in channel.json_body + event_id = channel.json_body["event_id"] + restrictions = self.get_success( + self.hs.get_datastores().main.get_media_restrictions( + content_uri.server_name, 
content_uri.media_id + ) + ) + assert restrictions is not None, str(restrictions) + assert restrictions.event_id == event_id + + # The other_user copies the media from local server + channel = make_request( + self.reactor, + self._hs_to_site[media_worker], + "POST", + f"/_matrix/client/unstable/org.matrix.msc3911/media/copy/{self.hs.hostname}/{media_id}", + access_token=self.other_user_tok, + ) + self.assertEqual(channel.code, 200) + self.assertIn("content_uri", channel.json_body) + new_media_id = channel.json_body["content_uri"].split("/")[-1] + assert new_media_id != media_id + + # Check if the original media there. + original_media = self.get_success( + self.hs.get_datastores().main.get_local_media(media_id) + ) + assert original_media is not None + assert original_media.user_id == self.user + + # Check the copied media. + copied_media = self.get_success( + self.hs.get_datastores().main.get_local_media(new_media_id) + ) + assert copied_media is not None + assert copied_media.user_id == self.other_user + + # Check if they are referencing the same image. + assert original_media.sha256 == copied_media.sha256 + + # Check if media is unattached to any event or user profile yet. 
+ assert copied_media.attachments is None + + original_media_download = self.fetch_media( + media_worker, + MXCUri.from_str(f"mxc://{self.hs.hostname}/{media_id}"), + self.user_tok, + ) + # This is a hex encoded byte stream of the raw file + old_media_payload = original_media_download.result.get("body") + assert old_media_payload is not None, old_media_payload + + new_media_download = self.fetch_media( + media_worker, + MXCUri.from_str(f"mxc://{self.hs.hostname}/{new_media_id}"), + self.other_user_tok, + ) + # Again, a hex encoded byte stream of the raw file + new_media_payload = new_media_download.result.get("body") + assert new_media_payload is not None + + # If they match, this was a successful copy + assert old_media_payload == new_media_payload + + def test_copy_local_restricted_resource_fails_when_requester_does_not_have_access( + self, + ) -> None: + """ + Tests that the new copy endpoint performs permission checks and it prevents the + copy when the requester does not have access to the original media. 
+ """ + media_worker = self.make_worker_hs( + "synapse.app.generic_worker", {"worker_name": "media_worker_1"} + ) + media_repo = media_worker.get_media_repository() + + # Create a private room + room_id = self.helper.create_room_as( + self.user, + is_public=False, + tok=self.user_tok, + extra_content={ + "initial_state": [ + { + "type": EventTypes.RoomHistoryVisibility, + "state_key": "", + "content": {"history_visibility": HistoryVisibility.JOINED}, + }, + ] + }, + ) + + # Create the media content + content_uri = self.get_success( + media_repo.create_or_update_content( + "image/png", + "test_png_upload", + io.BytesIO(SMALL_PNG), + 67, + UserID.from_string(self.user), + restricted=True, + ) + ) + # User sends a message with media + channel = self.make_request( + "PUT", + f"/rooms/{room_id}/send/m.room.message/{str(time.time())}?org.matrix.msc3911.attach_media={str(content_uri)}", + content={"msgtype": "m.text", "body": "Hi, this is a message"}, + access_token=self.user_tok, + ) + self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body) + assert "event_id" in channel.json_body + event_id = channel.json_body["event_id"] + restrictions = self.get_success( + self.hs.get_datastores().main.get_media_restrictions( + content_uri.server_name, content_uri.media_id + ) + ) + assert restrictions is not None, str(restrictions) + assert restrictions.event_id == event_id + + # Invite the other user + self.helper.invite(room_id, self.user, self.other_user, tok=self.user_tok) + self.helper.join(room_id, self.other_user, tok=self.other_user_tok) + + # User who does not have access to the media tries to copy it. 
+ channel = make_request( + self.reactor, + self._hs_to_site[media_worker], + "POST", + f"/_matrix/client/unstable/org.matrix.msc3911/media/copy/{self.hs.hostname}/{content_uri.media_id}", + access_token=self.other_user_tok, + ) + self.assertEqual(channel.code, 403) + + @override_config( + { + "limit_profile_requests_to_users_who_share_rooms": True, + } + ) + def test_copy_local_restricted_resource_fails_when_profile_lookup_is_not_allowed( + self, + ) -> None: + media_worker = self.make_worker_hs( + "synapse.app.generic_worker", {"worker_name": "media_worker_1"} + ) + media_repo = media_worker.get_media_repository() + # User setup a profile + content_uri = self.get_success( + media_repo.create_or_update_content( + "image/png", + "test_png_upload", + io.BytesIO(SMALL_PNG), + 67, + UserID.from_string(self.user), + restricted=True, + ) + ) + user_id = UserID.from_string(self.user) + self.get_success( + self.profile_handler.set_avatar_url( + user_id, create_requester(user_id), str(content_uri) + ) + ) + # The users do not share any rooms, and other user tries to copy the profile picture + channel = make_request( + self.reactor, + self._hs_to_site[media_worker], + "POST", + f"/_matrix/client/unstable/org.matrix.msc3911/media/copy/{self.hs.hostname}/{content_uri.media_id}", + access_token=self.other_user_tok, + ) + self.assertEqual(channel.code, 403) + + def test_copy_remote_restricted_resource(self) -> None: + """ + Tests that the new copy endpoint creates a new mxc uri for restricted resource. 
+ """ + media_worker = self.make_worker_hs( + "synapse.app.generic_worker", {"worker_name": "media_worker_1"} + ) + media_repo = media_worker.get_media_repository() + # create remote media + remote_server = "remoteserver.com" + media_id = "remotemedia" + remote_file_id = media_id + file_info = FileInfo(server_name=remote_server, file_id=remote_file_id) + + assert isinstance(media_repo, MediaRepository) + media_storage = media_repo.media_storage + ctx = media_storage.store_into_file(file_info) + (f, _) = self.get_success(ctx.__aenter__()) + f.write(SMALL_PNG) + self.get_success(ctx.__aexit__(None, None, None)) + self.get_success( + # The main store will not have authenticated media enabled, use the media repo + media_repo.store.store_cached_remote_media( + origin=remote_server, + media_id=media_id, + media_type="image/png", + media_length=67, + time_now_ms=self.clock.time_msec(), + upload_name="test.png", + filesystem_id=remote_file_id, + sha256=SMALL_PNG_SHA256, + restricted=True, + ) + ) + + # Remote media is attached to a user profile + remote_user_id = f"@remote-user:{remote_server}" + self.get_success( + self.hs.get_datastores().main.set_media_restricted_to_user_profile( + remote_server, media_id, remote_user_id + ) + ) + remote_media = self.get_success( + self.hs.get_datastores().main.get_cached_remote_media( + remote_server, media_id + ) + ) + assert remote_media is not None + assert remote_media.attachments is not None + assert str(remote_media.attachments.profile_user_id) == remote_user_id + + # The other_user copies the media from remote server + channel = make_request( + self.reactor, + self._hs_to_site[media_worker], + "POST", + f"/_matrix/client/unstable/org.matrix.msc3911/media/copy/{remote_server}/{media_id}", + access_token=self.other_user_tok, + ) + self.assertEqual(channel.code, 200) + self.assertIn("content_uri", channel.json_body) + new_media_id = channel.json_body["content_uri"].split("/")[-1] + assert new_media_id != media_id + + # Check if the 
original media there. + original_media = self.get_success( + self.hs.get_datastores().main.get_cached_remote_media( + remote_server, media_id + ) + ) + assert original_media is not None + assert original_media.upload_name == "test.png" + + # Check the copied media. + copied_media = self.get_success( + self.hs.get_datastores().main.get_local_media(new_media_id) + ) + assert copied_media is not None + assert copied_media.user_id == self.other_user + + # Check if they are referencing the same image. + assert original_media.sha256 == copied_media.sha256 + + # Check if copied media is unattached to any event or user profile yet. + assert copied_media.attachments is None + + original_media_download = self.fetch_media( + media_worker, + MXCUri.from_str(f"mxc://{remote_server}/{media_id}"), + self.user_tok, + ) + # This is a hex encoded byte stream of the raw file + old_media_payload = original_media_download.result.get("body") + assert old_media_payload is not None, old_media_payload + + new_media_download = self.fetch_media( + media_worker, + MXCUri.from_str(f"mxc://{self.hs.hostname}/{new_media_id}"), + self.other_user_tok, + ) + # Again, a hex encoded byte stream of the raw file + new_media_payload = new_media_download.result.get("body") + assert new_media_payload is not None + + # If they match, this was a successful copy + assert old_media_payload == new_media_payload + + @override_config( + { + "limit_profile_requests_to_users_who_share_rooms": True, + } + ) + def test_copy_remote_restricted_resource_fails_when_requester_does_not_have_access( + self, + ) -> None: + media_worker = self.make_worker_hs( + "synapse.app.generic_worker", {"worker_name": "media_worker_1"} + ) + media_repo = media_worker.get_media_repository() + + # Create remote media + remote_server = "remoteserver.com" + remote_file_id = "remote1" + file_info = FileInfo(server_name=remote_server, file_id=remote_file_id) + + assert isinstance(media_repo, MediaRepository) + media_storage = 
media_repo.media_storage + ctx = media_storage.store_into_file(file_info) + (f, _) = self.get_success(ctx.__aenter__()) + f.write(SMALL_PNG) + self.get_success(ctx.__aexit__(None, None, None)) + media_id = "remotemedia" + self.get_success( + # The main data store will not have authenticated media enabled, use the media repo + media_repo.store.store_cached_remote_media( + origin=remote_server, + media_id=media_id, + media_type="image/png", + media_length=1, + time_now_ms=self.clock.time_msec(), + upload_name="test.png", + filesystem_id=remote_file_id, + sha256=remote_file_id, + restricted=True, + ) + ) + + # Media is attached to a user profile + remote_user_id = f"@remote-user:{remote_server}" + self.get_success( + self.hs.get_datastores().main.set_media_restricted_to_user_profile( + remote_server, media_id, remote_user_id + ) + ) + remote_media = self.get_success( + self.hs.get_datastores().main.get_cached_remote_media( + remote_server, media_id + ) + ) + assert remote_media is not None + assert remote_media.attachments is not None + assert str(remote_media.attachments.profile_user_id) == remote_user_id + + # The other user tries to copy that media from remote server, but fails because + # user does not have the access to the profile_user_id + channel = make_request( + self.reactor, + self._hs_to_site[media_worker], + "POST", + f"/_matrix/client/unstable/org.matrix.msc3911/media/copy/{remote_server}/{media_id}", + access_token=self.other_user_tok, + ) + self.assertEqual(channel.code, 403) + + def _log_request(request: Request) -> None: """Implements Factory.log, which is expected by Request.finish""" logger.info("Completed request %s", request) diff --git a/tests/rest/client/test_media.py b/tests/rest/client/test_media.py index e2a25efc26..3e91c76308 100644 --- a/tests/rest/client/test_media.py +++ b/tests/rest/client/test_media.py @@ -60,6 +60,7 @@ from synapse.http.types import QueryParams from synapse.logging.context import make_deferred_yieldable from 
synapse.media._base import FileInfo, ThumbnailInfo +from synapse.media.media_repository import MediaRepository from synapse.media.thumbnailer import ThumbnailProvider from synapse.media.url_previewer import IMAGE_CACHE_EXPIRY_MS from synapse.rest import admin @@ -84,7 +85,7 @@ small_png_with_transparency, ) from tests.server import FakeChannel, FakeTransport, ThreadedMemoryReactorClock -from tests.test_utils import SMALL_PNG +from tests.test_utils import SMALL_PNG, SMALL_PNG_SHA256 from tests.unittest import override_config try: @@ -133,8 +134,9 @@ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: # from a regular 404. file_id = "abcdefg12345" file_info = FileInfo(server_name=self.remote_server_name, file_id=file_id) - - media_storage = hs.get_media_repository().media_storage + media_repo = hs.get_media_repository() + assert isinstance(media_repo, MediaRepository) + media_storage = media_repo.media_storage ctx = media_storage.store_into_file(file_info) (f, fname) = self.get_success(ctx.__aenter__()) @@ -273,6 +275,7 @@ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.media_repo = hs.get_media_repository() + assert isinstance(self.media_repo, MediaRepository) assert self.media_repo.url_previewer is not None self.url_previewer = self.media_repo.url_previewer @@ -1405,6 +1408,7 @@ def test_storage_providers_exclude_files(self) -> None: """Test that files are not stored in or fetched from storage providers.""" host, media_id = self._download_image() + assert isinstance(self.media_repo, MediaRepository) rel_file_path = self.media_repo.filepaths.url_cache_filepath_rel(media_id) media_store_path = os.path.join(self.media_store_path, rel_file_path) storage_provider_path = os.path.join(self.storage_path, rel_file_path) @@ -1447,6 +1451,7 @@ def test_storage_providers_exclude_thumbnails(self) -> None: """Test that thumbnails are 
not stored in or fetched from storage providers.""" host, media_id = self._download_image() + assert isinstance(self.media_repo, MediaRepository) rel_thumbnail_path = ( self.media_repo.filepaths.url_cache_thumbnail_directory_rel(media_id) ) @@ -1475,6 +1480,7 @@ def test_storage_providers_exclude_thumbnails(self) -> None: self.assertEqual(channel.code, 200) # Remove the original, otherwise thumbnails will regenerate + assert isinstance(self.media_repo, MediaRepository) rel_file_path = self.media_repo.filepaths.url_cache_filepath_rel(media_id) media_store_path = os.path.join(self.media_store_path, rel_file_path) os.remove(media_store_path) @@ -1500,6 +1506,7 @@ def test_cache_expiry(self) -> None: """Test that URL cache files and thumbnails are cleaned up properly on expiry.""" _host, media_id = self._download_image() + assert isinstance(self.media_repo, MediaRepository) file_path = self.media_repo.filepaths.url_cache_filepath(media_id) file_dirs = self.media_repo.filepaths.url_cache_filepath_dirs_to_delete( media_id @@ -2404,6 +2411,7 @@ def test_thumbnail_repeated_thumbnail(self) -> None: assert info is not None file_id = info.filesystem_id + assert isinstance(self.media_repo, MediaRepository) thumbnail_dir = self.media_repo.filepaths.remote_media_thumbnail_dir( self.remote, file_id ) @@ -2506,6 +2514,7 @@ def test_same_quality(self, method: str, desired_size: int) -> None: content_type = self.test_image.content_type.decode() media_repo = self.hs.get_media_repository() + assert isinstance(media_repo, MediaRepository) thumbnail_provider = ThumbnailProvider( self.hs, media_repo, media_repo.media_storage ) @@ -2646,7 +2655,9 @@ def test_authenticated_media(self) -> None: file_id = "abcdefg12345" file_info = FileInfo(server_name="lonelyIsland", file_id=file_id) - media_storage = self.hs.get_media_repository().media_storage + media_repo = self.hs.get_media_repository() + assert isinstance(media_repo, MediaRepository) + media_storage = media_repo.media_storage ctx = 
media_storage.store_into_file(file_info) (f, fname) = self.get_success(ctx.__aenter__()) @@ -2779,7 +2790,9 @@ def test_authenticated_media_etag(self) -> None: file_id = "abcdefg12345" file_info = FileInfo(server_name="lonelyIsland", file_id=file_id) - media_storage = self.hs.get_media_repository().media_storage + media_repo = self.hs.get_media_repository() + assert isinstance(media_repo, MediaRepository) + media_storage = media_repo.media_storage ctx = media_storage.store_into_file(file_info) (f, fname) = self.get_success(ctx.__aenter__()) @@ -3191,7 +3204,7 @@ def test_async_upload_restricted_resource(self) -> None: assert channel.code == 403, channel.json_body -class CopyRestrictedResource(unittest.HomeserverTestCase): +class CopyRestrictedResourceTestCase(unittest.HomeserverTestCase): """ Tests copy API when `msc3911_enabled` is configured to be True. """ @@ -3225,6 +3238,24 @@ def create_resource_dict(self) -> dict[str, Resource]: resources["/_matrix/media"] = self.hs.get_media_repository_resource() return resources + def fetch_media( + self, + mxc_uri: MXCUri, + access_token: Optional[str] = None, + expected_code: int = 200, + ) -> FakeChannel: + """ + Test retrieving the media. We do not care about the content of the media, just + that the response is correct + """ + channel = self.make_request( + "GET", + f"/_matrix/client/v1/media/download/{mxc_uri.server_name}/{mxc_uri.media_id}", + access_token=access_token, + ) + assert channel.code == expected_code, channel.code + return channel + def test_copy_local_restricted_resource(self) -> None: """ Tests that the new copy endpoint creates a new mxc uri for restricted resource. @@ -3311,6 +3342,24 @@ def test_copy_local_restricted_resource(self) -> None: # Check if media is unattached to any event or user profile yet. 
assert copied_media.attachments is None + original_media_download = self.fetch_media( + MXCUri.from_str(f"mxc://{self.hs.hostname}/{media_id}"), self.user_tok + ) + # This is a hex encoded byte stream of the raw file + old_media_payload = original_media_download.result.get("body") + assert old_media_payload is not None, old_media_payload + + new_media_download = self.fetch_media( + MXCUri.from_str(f"mxc://{self.hs.hostname}/{new_media_id}"), + self.other_user_tok, + ) + # Again, a hex encoded byte stream of the raw file + new_media_payload = new_media_download.result.get("body") + assert new_media_payload is not None + + # If they match, this was a successful copy + assert old_media_payload == new_media_payload + def test_copy_local_restricted_resource_fails_when_requester_does_not_have_access( self, ) -> None: @@ -3414,25 +3463,27 @@ def test_copy_remote_restricted_resource(self) -> None: """ # create remote media remote_server = "remoteserver.com" - remote_file_id = "remote1" + media_id = "remotemedia" + remote_file_id = media_id file_info = FileInfo(server_name=remote_server, file_id=remote_file_id) - media_storage = self.hs.get_media_repository().media_storage + media_repo = self.hs.get_media_repository() + assert isinstance(media_repo, MediaRepository) + media_storage = media_repo.media_storage ctx = media_storage.store_into_file(file_info) (f, _) = self.get_success(ctx.__aenter__()) f.write(SMALL_PNG) self.get_success(ctx.__aexit__(None, None, None)) - media_id = "remotemedia" self.get_success( self.hs.get_datastores().main.store_cached_remote_media( origin=remote_server, media_id=media_id, media_type="image/png", - media_length=1, + media_length=67, time_now_ms=self.clock.time_msec(), upload_name="test.png", filesystem_id=remote_file_id, - sha256=remote_file_id, + sha256=SMALL_PNG_SHA256, restricted=True, ) ) @@ -3486,6 +3537,24 @@ def test_copy_remote_restricted_resource(self) -> None: # Check if copied media is unattached to any event or user profile yet. 
assert copied_media.attachments is None + original_media_download = self.fetch_media( + MXCUri.from_str(f"mxc://{remote_server}/{media_id}"), self.user_tok + ) + # This is a hex encoded byte stream of the raw file + old_media_payload = original_media_download.result.get("body") + assert old_media_payload is not None, old_media_payload + + new_media_download = self.fetch_media( + MXCUri.from_str(f"mxc://{self.hs.hostname}/{new_media_id}"), + self.other_user_tok, + ) + # Again, a hex encoded byte stream of the raw file + new_media_payload = new_media_download.result.get("body") + assert new_media_payload is not None + + # If they match, this was a successful copy + assert old_media_payload == new_media_payload + @override_config( { "limit_profile_requests_to_users_who_share_rooms": True, @@ -3499,7 +3568,9 @@ def test_copy_remote_restricted_resource_fails_when_requester_does_not_have_acce remote_file_id = "remote1" file_info = FileInfo(server_name=remote_server, file_id=remote_file_id) - media_storage = self.hs.get_media_repository().media_storage + media_repo = self.hs.get_media_repository() + assert isinstance(media_repo, MediaRepository) + media_storage = media_repo.media_storage ctx = media_storage.store_into_file(file_info) (f, _) = self.get_success(ctx.__aenter__()) f.write(SMALL_PNG) diff --git a/tests/rest/media/test_domain_blocking.py b/tests/rest/media/test_domain_blocking.py index 26453f70dd..e2a6c6a7ce 100644 --- a/tests/rest/media/test_domain_blocking.py +++ b/tests/rest/media/test_domain_blocking.py @@ -24,6 +24,7 @@ from twisted.web.resource import Resource from synapse.media._base import FileInfo +from synapse.media.media_repository import MediaRepository from synapse.server import HomeServer from synapse.util import Clock @@ -44,8 +45,9 @@ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: # from a regular 404. 
file_id = "abcdefg12345" file_info = FileInfo(server_name=self.remote_server_name, file_id=file_id) - - media_storage = hs.get_media_repository().media_storage + media_repo = hs.get_media_repository() + assert isinstance(media_repo, MediaRepository) + media_storage = media_repo.media_storage ctx = media_storage.store_into_file(file_info) (f, fname) = self.get_success(ctx.__aenter__()) diff --git a/tests/rest/media/test_url_preview.py b/tests/rest/media/test_url_preview.py index 2a7bee19f9..57e5cde6d8 100644 --- a/tests/rest/media/test_url_preview.py +++ b/tests/rest/media/test_url_preview.py @@ -33,6 +33,7 @@ from twisted.web.resource import Resource from synapse.config.oembed import OEmbedEndpointConfig +from synapse.media.media_repository import MediaRepository from synapse.media.url_previewer import IMAGE_CACHE_EXPIRY_MS from synapse.server import HomeServer from synapse.types import JsonDict @@ -124,6 +125,7 @@ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.media_repo = hs.get_media_repository() + assert isinstance(self.media_repo, MediaRepository) assert self.media_repo.url_previewer is not None self.url_previewer = self.media_repo.url_previewer @@ -1265,6 +1267,7 @@ def test_storage_providers_exclude_files(self) -> None: """Test that files are not stored in or fetched from storage providers.""" host, media_id = self._download_image() + assert isinstance(self.media_repo, MediaRepository) rel_file_path = self.media_repo.filepaths.url_cache_filepath_rel(media_id) media_store_path = os.path.join(self.media_store_path, rel_file_path) storage_provider_path = os.path.join(self.storage_path, rel_file_path) @@ -1308,6 +1311,7 @@ def test_storage_providers_exclude_thumbnails(self) -> None: """Test that thumbnails are not stored in or fetched from storage providers.""" host, media_id = self._download_image() + assert isinstance(self.media_repo, 
MediaRepository) rel_thumbnail_path = ( self.media_repo.filepaths.url_cache_thumbnail_directory_rel(media_id) ) @@ -1336,6 +1340,7 @@ def test_storage_providers_exclude_thumbnails(self) -> None: self.assertEqual(channel.code, 200) # Remove the original, otherwise thumbnails will regenerate + assert isinstance(self.media_repo, MediaRepository) rel_file_path = self.media_repo.filepaths.url_cache_filepath_rel(media_id) media_store_path = os.path.join(self.media_store_path, rel_file_path) os.remove(media_store_path) @@ -1361,6 +1366,7 @@ def test_cache_expiry(self) -> None: """Test that URL cache files and thumbnails are cleaned up properly on expiry.""" _host, media_id = self._download_image() + assert isinstance(self.media_repo, MediaRepository) file_path = self.media_repo.filepaths.url_cache_filepath(media_id) file_dirs = self.media_repo.filepaths.url_cache_filepath_dirs_to_delete( media_id From 243eb4ac12c7fba39770f4344640618e9fed90dd Mon Sep 17 00:00:00 2001 From: Soyoung Kim Date: Fri, 12 Sep 2025 17:28:05 +0200 Subject: [PATCH 24/35] MSC3911: AP6 Automatic copy and attach media when updating member events --- synapse/handlers/room_member.py | 19 ++ tests/rest/client/test_media.py | 11 +- tests/rest/client/test_profile.py | 270 ++++++++++++++++ tests/rest/client/test_rooms.py | 507 ++++++++++++++++++++++++++++++ 4 files changed, 801 insertions(+), 6 deletions(-) diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index c58a327681..d7cfa4c05b 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -109,6 +109,7 @@ def __init__(self, hs: "HomeServer"): self.account_data_handler = hs.get_account_data_handler() self.event_auth_handler = hs.get_event_auth_handler() self._worker_lock_handler = hs.get_worker_locks_handler() + self.enable_restricted_media = hs.config.experimental.msc3911_enabled self._membership_types_to_include_profile_data_in = { Membership.JOIN, @@ -537,6 +538,7 @@ async def 
_local_membership_update( # we know it was persisted, so should have a stream ordering assert result_event.internal_metadata.stream_ordering + return result_event.event_id, result_event.internal_metadata.stream_ordering async def copy_room_tags_and_direct_to_room( @@ -860,6 +862,23 @@ async def update_membership_locked( except Exception as e: logger.info("Failed to get profile information for %r: %s", target, e) + if self.enable_restricted_media and not media_info_for_attachment: + # Other than membership + avatar_url = content.get("avatar_url") + if avatar_url: + # Something about the MediaRepository does not like being part of + # the initialization code of the RoomMemberHandler, so just import + # it on the spot instead. + media_repo = self.hs.get_media_repository() + + new_mxc_uri = await media_repo.copy_media( + MXCUri.from_str(avatar_url), requester.user, 20_000 + ) + media_object = await media_repo.get_media_info(new_mxc_uri) + assert isinstance(media_object, LocalMedia) + media_info_for_attachment = {media_object} + content[EventContentFields.MEMBERSHIP_AVATAR_URL] = str(new_mxc_uri) + # if this is a join with a 3pid signature, we may need to turn a 3pid # invite into a normal invite before we can handle the join. 
if third_party_signed is not None: diff --git a/tests/rest/client/test_media.py b/tests/rest/client/test_media.py index ee449f190e..78141777b5 100644 --- a/tests/rest/client/test_media.py +++ b/tests/rest/client/test_media.py @@ -4646,6 +4646,9 @@ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.repo = hs.get_media_repository() + assert isinstance(self.repo, MediaRepository) + self.media_storage = self.repo.media_storage + self.client = hs.get_federation_http_client() self.store = hs.get_datastores().main self.user = self.register_user("user", "pass") @@ -4721,9 +4724,7 @@ def inject_local_media_and_send_event(self, authed: bool, restricted: bool) -> s file_id = media_id file_info = FileInfo(None, file_id=file_id) - media_storage = self.hs.get_media_repository().media_storage - - ctx = media_storage.store_into_file(file_info) + ctx = self.media_storage.store_into_file(file_info) (f, fname) = self.get_success(ctx.__aenter__()) f.write(SMALL_PNG) self.get_success(ctx.__aexit__(None, None, None)) @@ -4784,9 +4785,7 @@ def inject_remote_media(self, restricted: bool) -> str: file_id = media_id file_info = FileInfo(server_name=self.other_server_name, file_id=file_id) - media_storage = self.hs.get_media_repository().media_storage - - ctx = media_storage.store_into_file(file_info) + ctx = self.media_storage.store_into_file(file_info) (f, fname) = self.get_success(ctx.__aenter__()) f.write(SMALL_PNG) self.get_success(ctx.__aexit__(None, None, None)) diff --git a/tests/rest/client/test_profile.py b/tests/rest/client/test_profile.py index 7528c07661..110d6bdd0f 100644 --- a/tests/rest/client/test_profile.py +++ b/tests/rest/client/test_profile.py @@ -36,11 +36,14 @@ from synapse.rest import admin from synapse.rest.client import login, media, profile, room from synapse.server import HomeServer +from synapse.storage.database import LoggingTransaction from 
synapse.storage.databases.main.profile import MAX_PROFILE_SIZE from synapse.types import JsonDict, UserID from synapse.util import Clock from tests import unittest +from tests.replication._base import BaseMultiWorkerStreamTestCase +from tests.server import FakeChannel, make_request from tests.test_utils import SMALL_PNG from tests.unittest import override_config from tests.utils import USE_POSTGRES_FOR_TESTS @@ -923,6 +926,7 @@ class ProfileMediaAttachmentTestCase(unittest.HomeserverTestCase): login.register_servlets, media.register_servlets, profile.register_servlets, + room.register_servlets, ] def prepare( @@ -966,6 +970,29 @@ def create_media_and_set_restricted_flag(self, user_id: str) -> MXCUri: ) return content_uri + def get_media_id_by_attached_event_id(self, event_id: str) -> Optional[str]: + sql = """ + SELECT media_id + FROM media_attachments + WHERE restrictions_json->'restrictions'->>'event_id' = ?; + """ + + def _get_media_id_by_attached_event_id_txn( + txn: LoggingTransaction, + ) -> Optional[str]: + txn.execute(sql, (event_id,)) + row = txn.fetchone() + if not row: + return None + return row[0] + + return self.get_success( + self.store.db_pool.runInteraction( + "get_media_id_by_attached_event_id", + _get_media_id_by_attached_event_id_txn, + ) + ) + def test_can_attach_media_to_profile_update(self) -> None: """ Test basic functionality, that a media ID can be attached to a user profile id. @@ -1180,3 +1207,246 @@ def test_remove_media_from_profile(self) -> None: ) ) assert user_avatar is None + + def test_profile_update_with_media_is_copied_and_attached_to_member_events( + self, + ) -> None: + room_id = self.helper.create_room_as(self.user, is_public=True, tok=self.tok) + self.helper.join(room_id, self.other_user, tok=self.other_tok) + mxc_uri = self.create_media_and_set_restricted_flag(self.user) + + # Attach the media to the user profile. 
+ channel = self.make_request( + "PUT", + f"/_matrix/client/v3/profile/{self.user}/avatar_url?propagate=true", + access_token=self.tok, + content={"avatar_url": str(mxc_uri)}, + ) + assert channel.code == HTTPStatus.OK + assert channel.json_body == {} + + # Check media is set as user avatar. + user_avatar = self.get_success( + self.store.get_profile_avatar_url( + UserID.from_string(self.user), + ) + ) + assert user_avatar is not None + assert user_avatar == str(mxc_uri) + + # Check media restrictions + media_info = self.get_success(self.store.get_local_media(mxc_uri.media_id)) + assert media_info is not None + assert media_info.attachments is not None + assert media_info.attachments.profile_user_id == UserID.from_string(self.user) + + # Check the media was copied and attached to a member event + events = self.get_success( + self.store.get_events_sent_by_user_in_room( + self.user, room_id, 10, ["m.room.member"] + ) + ) + # Get member event id + assert events is not None + member_event_id = events[0] + + copied_media_id = self.get_media_id_by_attached_event_id(member_event_id) + + assert copied_media_id is not None + assert copied_media_id != mxc_uri.media_id + + media_restrictions = self.get_success( + self.store.get_media_restrictions(self.server_name, copied_media_id) + ) + assert media_restrictions is not None + assert media_restrictions.event_id == member_event_id + + +class ProfileMediaAttachmentReplicationTestCase(BaseMultiWorkerStreamTestCase): + """ + Test that a membership event that is supposed to copy media appropriately replicates + the call from a generic room worker to the main process where media is handled + """ + + servlets = [ + admin.register_servlets, + login.register_servlets, + media.register_servlets, + # profile.register_servlets, + room.register_servlets, + ] + + def prepare( + self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer + ) -> None: + self.store = homeserver.get_datastores().main + self.server_name = 
self.hs.config.server.server_name + self.media_repo = self.hs.get_media_repository() + + self.user = self.register_user("user", "password") + self.tok = self.login("user", "password") + + self.other_user = self.register_user("other_user", "password") + self.other_tok = self.login("other_user", "password") + + def default_config(self) -> JsonDict: + config = super().default_config() + config.setdefault("experimental_features", {}) + config["experimental_features"].update({"msc3911_enabled": True}) + # config["media_repo_instances"] = [MAIN_PROCESS_INSTANCE_NAME] + + return config + + def create_resource_dict(self) -> dict[str, Resource]: + resources = super().create_resource_dict() + resources["/_matrix/media"] = self.hs.get_media_repository_resource() + return resources + + def create_media_and_set_restricted_flag(self, user_id: str) -> MXCUri: + """ + Create media without using an endpoint, and set the restricted flag. + """ + content = io.BytesIO(SMALL_PNG) + content_uri = self.get_success( + self.media_repo.create_or_update_content( + "image/png", + "test_png_upload", + content, + 67, + UserID.from_string(user_id), + restricted=True, + ) + ) + return content_uri + + def make_worker_hs( + self, worker_app: str, extra_config: Optional[dict] = None, **kwargs: Any + ) -> HomeServer: + worker_hs = super().make_worker_hs(worker_app, extra_config, **kwargs) + # Mount the room resource onto the worker. 
+ # worker_hs.get_media_repository_resource().register_servlets( + # self._hs_to_site[worker_hs].resource, worker_hs + # ) + room.register_servlets(worker_hs, self._hs_to_site[worker_hs].resource) + profile.register_servlets(worker_hs, self._hs_to_site[worker_hs].resource) + return worker_hs + + def get_media_id_by_attached_event_id(self, event_id: str) -> Optional[str]: + sql = """ + SELECT media_id + FROM media_attachments + WHERE restrictions_json->'restrictions'->>'event_id' = ?; + """ + + def _get_media_id_by_attached_event_id_txn( + txn: LoggingTransaction, + ) -> Optional[str]: + txn.execute(sql, (event_id,)) + row = txn.fetchone() + if not row: + return None + return row[0] + + return self.get_success( + self.store.db_pool.runInteraction( + "get_media_id_by_attached_event_id", + _get_media_id_by_attached_event_id_txn, + ) + ) + + def fetch_media( + self, + mxc_uri: MXCUri, + access_token: Optional[str] = None, + expected_code: int = 200, + ) -> FakeChannel: + """ + Test retrieving the media. Assert the response code, but return the channel, the + test may have use of it. 
For this test case series, the media repo is the same + as the main process + """ + channel = self.make_request( + "GET", + f"/_matrix/client/v1/media/download/{mxc_uri.server_name}/{mxc_uri.media_id}", + access_token=access_token, + ) + assert channel.code == expected_code, channel.code + return channel + + def assert_media_is_identical(self, first_mxc: MXCUri, second_mxc: MXCUri) -> None: + """ + Verify that both media object are actually byte-for-byte identical + """ + first_media = self.fetch_media(first_mxc, self.tok) + first_media_payload = first_media.result.get("body") + assert first_media_payload is not None + + second_media = self.fetch_media(second_mxc, self.tok) + second_media_payload = second_media.result.get("body") + assert second_media_payload is not None + + assert first_media_payload == second_media_payload + + def test_profile_update_with_media_is_copied_and_attached_to_member_events( + self, + ) -> None: + generic_worker = self.make_worker_hs("synapse.app.generic_worker") + + room_id = self.helper.create_room_as(self.user, is_public=True, tok=self.tok) + self.helper.join(room_id, self.other_user, tok=self.other_tok) + mxc_uri = self.create_media_and_set_restricted_flag(self.user) + + # Attach the media to the user profile. + channel = make_request( + self.reactor, + self._hs_to_site[generic_worker], + "PUT", + f"/_matrix/client/v3/profile/{self.user}/avatar_url?propagate=true", + access_token=self.tok, + content={"avatar_url": str(mxc_uri)}, + ) + assert channel.code == HTTPStatus.OK + assert channel.json_body == {} + + # Check media is set as user avatar. 
+ user_avatar = self.get_success( + self.store.get_profile_avatar_url( + UserID.from_string(self.user), + ) + ) + assert user_avatar is not None + assert user_avatar == str(mxc_uri) + + # Check media restrictions + media_info = self.get_success(self.store.get_local_media(mxc_uri.media_id)) + assert media_info is not None + assert media_info.attachments is not None + assert media_info.attachments.profile_user_id == UserID.from_string(self.user) + + # Check the media was copied and attached to a member event + events = self.get_success( + self.store.get_events_sent_by_user_in_room( + self.user, room_id, 10, ["m.room.member"] + ) + ) + # Get member event id + assert events is not None + member_event_id = events[0] + + copied_media_id = self.get_media_id_by_attached_event_id(member_event_id) + + assert copied_media_id is not None, copied_media_id + assert copied_media_id != mxc_uri.media_id + + media_restrictions = self.get_success( + self.store.get_media_restrictions(self.server_name, copied_media_id) + ) + assert media_restrictions is not None + assert media_restrictions.event_id == member_event_id + + # the media ID should be different, but the server_name will be for the local + # host. 
Verify that both media byte streams are the same + new_media_mxc_uri = MXCUri( + server_name=self.server_name, media_id=copied_media_id + ) + self.assert_media_is_identical(mxc_uri, new_media_mxc_uri) diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py index b571453c0d..da1afbc94c 100644 --- a/tests/rest/client/test_rooms.py +++ b/tests/rest/client/test_rooms.py @@ -59,15 +59,18 @@ profile, register, room, + room_upgrade_rest_servlet, sync, ) from synapse.server import HomeServer +from synapse.storage.database import LoggingTransaction from synapse.types import JsonDict, RoomAlias, UserID, create_requester from synapse.util import Clock from synapse.util.stringutils import random_string from tests import unittest from tests.http.server._base import make_request_with_cancellation_test +from tests.server import FakeChannel from tests.storage.test_stream import PaginationTestCase from tests.test_utils import SMALL_PNG from tests.test_utils.event_injection import create_event @@ -5265,3 +5268,507 @@ def test_create_room_fails_with_malformed_room_avatar_url(self) -> None: """Test that a malformed room avatar url fails the room creation""" room_id = self.create_room_with_avatar(avatar_mxc="junk", expected_code=400) assert room_id is None + + +class RoomMemberEventMediaAttachmentTestCase(unittest.HomeserverTestCase): + servlets = [ + admin.register_servlets, + login.register_servlets, + media.register_servlets, + profile.register_servlets, + knock.register_servlets, + room_upgrade_rest_servlet.register_servlets, + room.register_servlets, + room.register_deprecated_servlets, + ] + + def prepare( + self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer + ) -> None: + self.store = homeserver.get_datastores().main + self.server_name = self.hs.config.server.server_name + self.media_repo = self.hs.get_media_repository() + + self.user = self.register_user("room_creator", "password") + self.tok = self.login("room_creator", "password") + + 
self.other_user = self.register_user("invitee", "password") + self.other_tok = self.login("invitee", "password") + + # Use these to set the displaynames on the profiles + self.displaynames = { + self.user: "Room Creator", + self.other_user: "Other User", + } + + def default_config(self) -> JsonDict: + config = super().default_config() + config.setdefault("experimental_features", {}) + config["experimental_features"].update({"msc3911_enabled": True}) + return config + + def create_media_and_set_as_profile_avatar(self, user_id: str, tok: str) -> MXCUri: + """ + Create media without using the media endpoints directly, and set the + restricted flag on the media. + """ + content = io.BytesIO(SMALL_PNG) + content_uri = self.get_success( + self.media_repo.create_or_update_content( + "image/png", + "test_png_upload", + content, + 67, + UserID.from_string(user_id), + restricted=True, + ) + ) + + channel = self.make_request( + "PUT", + f"/_matrix/client/v3/profile/{user_id}/displayname?propagate=true", + access_token=tok, + content={"displayname": self.displaynames[user_id]}, + ) + assert channel.code == HTTPStatus.OK, channel.result["body"] + assert channel.json_body == {} + + channel = self.make_request( + "PUT", + f"/_matrix/client/v3/profile/{user_id}/avatar_url?propagate=true", + access_token=tok, + content={"avatar_url": str(content_uri)}, + ) + assert channel.code == HTTPStatus.OK, channel.result["body"] + assert channel.json_body == {} + + retrieved_restrictions = self.get_success_or_raise( + self.store.get_media_restrictions(self.server_name, content_uri.media_id) + ) + assert retrieved_restrictions is not None + assert str(retrieved_restrictions.profile_user_id) == user_id + + return content_uri + + def get_media_id_by_attached_event_id(self, event_id: str) -> Optional[str]: + sql = """ + SELECT media_id + FROM media_attachments + WHERE restrictions_json->'restrictions'->>'event_id' = ?; + """ + + def _get_media_id_by_attached_event_id_txn( + txn: 
LoggingTransaction, + ) -> Optional[str]: + txn.execute(sql, (event_id,)) + row = txn.fetchone() + if not row: + return None + return row[0] + + return self.get_success( + self.store.db_pool.runInteraction( + "get_media_id_by_attached_event_id", + _get_media_id_by_attached_event_id_txn, + ) + ) + + def fetch_media( + self, + mxc_uri: MXCUri, + access_token: Optional[str] = None, + expected_code: int = 200, + ) -> FakeChannel: + """ + Test retrieving the media. Assert the response code, but return the channel, the + test may have use of it. For this test case series, the media repo is the same + as the main process + """ + channel = self.make_request( + "GET", + f"/_matrix/client/v1/media/download/{mxc_uri.server_name}/{mxc_uri.media_id}", + access_token=access_token, + ) + assert channel.code == expected_code, channel.code + return channel + + def assert_media_is_identical( + self, first_mxc: MXCUri, second_mxc: MXCUri, access_token: str + ) -> None: + """ + Verify that both media object are actually byte-for-byte identical + """ + first_media = self.fetch_media(first_mxc, access_token) + first_media_payload = first_media.result.get("body") + assert first_media_payload is not None + + second_media = self.fetch_media(second_mxc, access_token) + second_media_payload = second_media.result.get("body") + assert second_media_payload is not None + + assert first_media_payload == second_media_payload + + def test_join_with_media_get_copied_and_attached_to_event(self) -> None: + """Test that member event by join copies and attaches the media to the event.""" + # Set profile avatar for joining user + joiner_avatar_mxc_uri = self.create_media_and_set_as_profile_avatar( + self.other_user, self.other_tok + ) + + # Create a room and invite the other_user + room_id = self.helper.create_room_as(self.user, is_public=True, tok=self.tok) + channel = self.make_request( + "POST", + f"/_matrix/client/v3/rooms/{room_id}/join", + access_token=self.other_tok, + content={"user_id": 
self.other_user}, + ) + assert channel.code == 200, channel.result["body"] + + # Get the member event of the join just occurred. + joiner_event_ids = self.get_success( + self.store.get_membership_event_ids_for_user(self.other_user, room_id) + ) + assert len(joiner_event_ids) == 1 + + event_id = joiner_event_ids.pop() + assert event_id is not None + event = self.get_success(self.store.get_event(event_id)) + assert event.type == EventTypes.Member + + # Verify that this event is the correct one + assert event.content.get(EventContentFields.MEMBERSHIP_DISPLAYNAME) != str( + joiner_avatar_mxc_uri + ) + + # Verify the display name is unchanged + assert ( + event.content.get(EventContentFields.MEMBERSHIP_DISPLAYNAME) + == self.displaynames[self.other_user] + ) + + # Find the copied media by the event id + copied_media_id = self.get_media_id_by_attached_event_id(event_id) + + assert copied_media_id is not None + assert copied_media_id != joiner_avatar_mxc_uri.media_id + + # Check if copied media has attached to the event + copied_media_restrictions = self.get_success( + self.store.get_media_restrictions(self.server_name, copied_media_id) + ) + assert copied_media_restrictions is not None + assert copied_media_restrictions.event_id == event_id + + # the media ID should be different, but the server_name will be for the local + # host. 
Verify that both media byte streams are the same + new_media_mxc_uri = MXCUri( + server_name=self.server_name, media_id=copied_media_id + ) + self.assert_media_is_identical( + joiner_avatar_mxc_uri, new_media_mxc_uri, self.other_tok + ) + + def test_invite_with_media_get_copied_and_attached_to_event(self) -> None: + """Test that member event by invite copies and attaches the media to the event.""" + # Set profile avatar for invited user + invitee_avatar_mxc_uri = self.create_media_and_set_as_profile_avatar( + self.other_user, self.other_tok + ) + + # Create a room + room_id = self.helper.create_room_as(self.user, is_public=True, tok=self.tok) + + # Invite the other_user + channel = self.make_request( + "POST", + f"/_matrix/client/v3/rooms/{room_id}/invite", + access_token=self.tok, + content={"user_id": self.other_user}, + ) + assert channel.code == 200, channel.result["body"] + + # Get the member event of the invite just occurred. + # creator_event_ids = self.get_success(self.store.get_membership_event_ids_for_user(self.user, room_id)) + # assert len(creator_event_ids) == 1 + invitee_event_ids = self.get_success( + self.store.get_membership_event_ids_for_user(self.other_user, room_id) + ) + assert len(invitee_event_ids) == 1 + + event_id = invitee_event_ids.pop() + assert event_id is not None + event = self.get_success(self.store.get_event(event_id)) + assert event.type == EventTypes.Member + + # Verify that this event a different mxc + assert event.content.get(EventContentFields.MEMBERSHIP_DISPLAYNAME) != str( + invitee_avatar_mxc_uri + ) + + # Verify the display name is unchanged + assert ( + event.content.get(EventContentFields.MEMBERSHIP_DISPLAYNAME) + == self.displaynames[self.other_user] + ) + + # Find the copied media by the event id + copied_media_id = self.get_media_id_by_attached_event_id(event_id) + + assert copied_media_id is not None + assert copied_media_id != invitee_avatar_mxc_uri.media_id + + # Check if copied media has attached to the event + 
copied_media_restrictions = self.get_success( + self.store.get_media_restrictions(self.server_name, copied_media_id) + ) + assert copied_media_restrictions is not None + assert copied_media_restrictions.event_id == event_id + + # the media ID should be different, but the server_name will be for the local + # host. Verify that both media byte streams are the same + new_media_mxc_uri = MXCUri( + server_name=self.server_name, media_id=copied_media_id + ) + self.assert_media_is_identical( + invitee_avatar_mxc_uri, new_media_mxc_uri, self.other_tok + ) + + def test_knock_with_media_get_copied_and_attached_to_event(self) -> None: + """Test that member event by knock copies and attaches the media to the event.""" + # Set profile avatar for knocking user + knocker_avatar_mxc_uri = self.create_media_and_set_as_profile_avatar( + self.other_user, self.other_tok + ) + + # Create a knockable room + room_id = self.helper.create_room_as(self.user, is_public=True, tok=self.tok) + channel = self.make_request( + "PUT", + f"/rooms/{room_id}/state/m.room.join_rules", + {"join_rule": "knock"}, + access_token=self.tok, + ) + assert channel.code == 200 + + # The other_user knocks on the room + channel = self.make_request( + "POST", + f"/_matrix/client/v3/knock/{room_id}", + access_token=self.other_tok, + content={}, + ) + assert channel.code == 200, channel.result + + # Get the member event of the join just occurred. 
+ knocker_event_ids = self.get_success( + self.store.get_membership_event_ids_for_user(self.other_user, room_id) + ) + assert len(knocker_event_ids) == 1 + + member_event_id = knocker_event_ids.pop() + assert member_event_id is not None + event = self.get_success(self.store.get_event(member_event_id)) + assert event.type == EventTypes.Member + assert event.content.get(EventContentFields.MEMBERSHIP_DISPLAYNAME) != str( + knocker_avatar_mxc_uri + ) + + # Verify the display name is unchanged + assert ( + event.content.get(EventContentFields.MEMBERSHIP_DISPLAYNAME) + == self.displaynames[self.other_user] + ) + + # Find the copied media by the event id + copied_media_id = self.get_media_id_by_attached_event_id(member_event_id) + + assert copied_media_id is not None + assert copied_media_id != knocker_avatar_mxc_uri.media_id + + # Check if copied media has attached to the event + copied_media_restrictions = self.get_success( + self.store.get_media_restrictions(self.server_name, copied_media_id) + ) + assert copied_media_restrictions is not None + assert copied_media_restrictions.event_id == member_event_id + + # the media ID should be different, but the server_name will be for the local + # host. 
Verify that both media byte streams are the same + new_media_mxc_uri = MXCUri( + server_name=self.server_name, media_id=copied_media_id + ) + self.assert_media_is_identical( + knocker_avatar_mxc_uri, new_media_mxc_uri, self.other_tok + ) + + def test_room_creator_and_invitee_with_media_get_copied_and_attached_to_event( + self, + ) -> None: + """Test that member event of room creation copies and attaches the media to the event.""" + creator_mxc_uri = self.create_media_and_set_as_profile_avatar( + self.user, self.tok + ) + invitee_mxc_uri = self.create_media_and_set_as_profile_avatar( + self.other_user, self.other_tok + ) + + # Send the createRoom request + channel = self.make_request( + "POST", + "/_matrix/client/v3/createRoom", + access_token=self.tok, + content={"name": "new room", "invite": [self.other_user]}, + ) + assert channel.code == HTTPStatus.OK, channel.result["body"] + room_id = channel.json_body["room_id"] + + # Get member events of the room creation + creator_event_ids = self.get_success( + self.store.get_membership_event_ids_for_user(self.user, room_id) + ) + assert len(creator_event_ids) == 1 + invitee_event_ids = self.get_success( + self.store.get_membership_event_ids_for_user(self.other_user, room_id) + ) + assert len(invitee_event_ids) == 1 + + # Check if creator's member event has copied the image + creator_member_event_id = creator_event_ids.pop() + assert creator_member_event_id is not None + event = self.get_success(self.store.get_event(creator_member_event_id)) + assert event.type == EventTypes.Member + # It should be different, as the profile was set before the room was made and + # a membership event is created for that room + assert event.content.get(EventContentFields.MEMBERSHIP_AVATAR_URL) != str( + creator_mxc_uri + ) + + # Verify the display name is unchanged + assert ( + event.content.get(EventContentFields.MEMBERSHIP_DISPLAYNAME) + == self.displaynames[self.user] + ) + + # Check if creator's profile avatar is copied and attached to 
the member event + creators_copied_media_id = self.get_media_id_by_attached_event_id( + creator_member_event_id + ) + + assert creators_copied_media_id is not None + assert creators_copied_media_id != creator_mxc_uri.media_id + + creators_copied_media_restrictions = self.get_success( + self.store.get_media_restrictions( + self.server_name, creators_copied_media_id + ) + ) + assert creators_copied_media_restrictions is not None + assert creators_copied_media_restrictions.event_id == creator_member_event_id + + # the media ID should be different, but the server_name will be for the local + # host. Verify that both media byte streams are the same + creators_new_media_mxc_uri = MXCUri( + server_name=self.server_name, media_id=creators_copied_media_id + ) + self.assert_media_is_identical( + creator_mxc_uri, creators_new_media_mxc_uri, self.tok + ) + + # Check if invitee's member event has copied the image + invitee_member_event_id = invitee_event_ids.pop() + assert invitee_member_event_id is not None + event = self.get_success(self.store.get_event(invitee_member_event_id)) + assert event.type == EventTypes.Member + # It should be different, as the profile was set before the room was made and + # a membership event is created for that room + assert event.content.get(EventContentFields.MEMBERSHIP_DISPLAYNAME) != str( + invitee_mxc_uri + ) + + # Verify the display name is unchanged + assert ( + event.content.get(EventContentFields.MEMBERSHIP_DISPLAYNAME) + == self.displaynames[self.other_user] + ) + + # Check if invitee's profile avatar is copied and attached to the member event + invitees_copied_media_id = self.get_media_id_by_attached_event_id( + invitee_member_event_id + ) + + assert invitees_copied_media_id is not None + assert invitees_copied_media_id != invitee_mxc_uri.media_id + + invitees_copied_media_restrictions = self.get_success( + self.store.get_media_restrictions( + self.server_name, invitees_copied_media_id + ) + ) + assert invitees_copied_media_restrictions 
is not None + assert invitees_copied_media_restrictions.event_id == invitee_member_event_id + + # Have to verify the invitees mxc uri got copied identically too + invitees_new_media_mxc_uri = MXCUri( + server_name=self.server_name, media_id=invitees_copied_media_id + ) + self.assert_media_is_identical( + invitee_mxc_uri, invitees_new_media_mxc_uri, self.other_tok + ) + + def test_room_upgrader_with_media_get_copied_and_attached_to_event(self) -> None: + """Test that member event of room upgrade copies and attaches the media to the event.""" + mxc_uri = self.create_media_and_set_as_profile_avatar(self.user, self.tok) + room_id = self.helper.create_room_as(self.user, is_public=True, tok=self.tok) + + # Upgrade the room + channel = self.make_request( + "POST", + f"/_matrix/client/v3/rooms/{room_id}/upgrade", + access_token=self.tok, + content={"new_version": "10"}, + ) + assert channel.code == HTTPStatus.OK + room_id = channel.json_body["replacement_room"] + + # Get member events of the room upgrade + upgrader_event_ids = self.get_success( + self.store.get_membership_event_ids_for_user(self.user, room_id) + ) + assert len(upgrader_event_ids) == 1 + + creator_member_event_id = upgrader_event_ids.pop() + assert creator_member_event_id is not None + event = self.get_success(self.store.get_event(creator_member_event_id)) + assert event.type == EventTypes.Member + assert event.content.get(EventContentFields.MEMBERSHIP_DISPLAYNAME) != str( + mxc_uri + ) + + # Verify the display name is unchanged + assert ( + event.content.get(EventContentFields.MEMBERSHIP_DISPLAYNAME) + == self.displaynames[self.user] + ) + + # Check if member event has copied the image + copied_media_id = self.get_media_id_by_attached_event_id( + creator_member_event_id + ) + + assert copied_media_id is not None + assert copied_media_id != mxc_uri.media_id + + # Check if the profile avatar is copied and attached to the member event + copied_media_restrictions = self.get_success( + 
self.store.get_media_restrictions(self.server_name, copied_media_id) + ) + assert copied_media_restrictions is not None + assert copied_media_restrictions.event_id == creator_member_event_id + + # the media ID should be different, but the server_name will be for the local + # host. Verify that both media byte streams are the same + new_media_mxc_uri = MXCUri( + server_name=self.server_name, media_id=copied_media_id + ) + self.assert_media_is_identical(mxc_uri, new_media_mxc_uri, self.other_tok) From ac8750dba109c7409aaa1452ef927b3f49226e95 Mon Sep 17 00:00:00 2001 From: Jason Little Date: Tue, 16 Sep 2025 17:37:32 -0500 Subject: [PATCH 25/35] Swap out the 'get_state_for_events()' method to retrieve state mappings This particular function is not particularly usable when early in a room and applying a filter for events that do not exist yet --- synapse/media/media_repository.py | 30 +++++++++++++++++++++--------- 1 file changed, 21 insertions(+), 9 deletions(-) diff --git a/synapse/media/media_repository.py b/synapse/media/media_repository.py index 5eb583d018..a505da6f36 100644 --- a/synapse/media/media_repository.py +++ b/synapse/media/media_repository.py @@ -339,7 +339,6 @@ async def is_media_visible( if attached_event_id: event_base = await self.store.get_event(attached_event_id) - storage_controllers = self.hs.get_storage_controllers() if event_base.is_state(): # The standard event visibility utility, filter_events_for_client(), # does not seem to meet the needs of a good UX when restricting and @@ -367,17 +366,29 @@ async def is_media_visible( membership_state_key = (EventTypes.Member, requesting_user.to_string()) types = (_HISTORY_VIS_KEY, membership_state_key) + # and history visibility and membership of THEN - event_id_to_state = ( - await storage_controllers.state.get_state_for_events( - [attached_event_id], - state_filter=StateFilter.from_types(types), - ) + state_filter = StateFilter.from_types(types) + state_handler = self.hs.get_state_handler() + + # 
State Map to Event IDs + state_map_to_e_id = await state_handler.compute_state_after_events( + event_base.room_id, [attached_event_id], state_filter=state_filter + ) + # Get the EventBases for those Event IDs + events = await self.store.get_events( + state_map_to_e_id.values(), ) + # Sort into mapping of StateMap to EventBases + state_map = { + k: events[v] for k, v in state_map_to_e_id.items() if v in events + } - state_map = event_id_to_state.get(attached_event_id) - # Do we need to guard against not having state of a room? - assert state_map is not None + # Don't need to make sure we have an actual StateMap. The defaults + # applied below handle those occasions. E.g. if it is early in a room + # at the point of the event we are trying to get visibility on, the + # state may not exist yet for these filtered events. Like for the + # membership event that follows room creation. visibility = get_effective_room_visibility_from_state(state_map) @@ -443,6 +454,7 @@ async def is_media_visible( return else: + storage_controllers = self.hs.get_storage_controllers() filtered_events = await filter_events_for_client( storage_controllers, requesting_user.to_string(), From 4b4154ffc52ae7fa62f44f20411d91e83e9c66f3 Mon Sep 17 00:00:00 2001 From: Jason Little Date: Wed, 17 Sep 2025 08:23:29 -0500 Subject: [PATCH 26/35] fix: Avoid partial room creation due to errors in room avatars with respect to media restriction --- synapse/handlers/room.py | 61 ++++++++++++++++++++++++++++++++++++---- 1 file changed, 55 insertions(+), 6 deletions(-) diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 7451feca16..0306999284 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -1220,6 +1220,61 @@ async def create_room( # override any attempt to set room versions via the creation_content creation_content["room_version"] = room_version.identifier + # In order to prevent partial room entries in the database, pre-sanitize and + # validate the data. 
No data should be persisted at this time. + # Currently, only checks the m.room.avatar entry. If it exists. + raw_initial_state = config.get("initial_state", []) + + initial_state = OrderedDict() + for val in raw_initial_state: + initial_state[(val["type"], val.get("state_key", ""))] = val["content"] + + room_avatar = initial_state.get((EventTypes.RoomAvatar, "")) + # It is unfortunate that this conditional block has to be run twice: once here + # to validate the data and again to actually attach the media reference. + if room_avatar is not None and self.config.experimental.msc3911_enabled: + # this should be an mxc, but the spec does not specifically say it has to be + extracted_media_id: Optional[str] = room_avatar.get("url") + # It may be that "url" is set to either an empty string or None. Accept + # this gracefully, to account for backwards compatible behavior + if extracted_media_id: + try: + mxc_uri = MXCUri.from_str(extracted_media_id) + except ValueError: + raise SynapseError( + HTTPStatus.BAD_REQUEST, + f"Room avatar MXC Uri ('{extracted_media_id}') is malformed", + Codes.INVALID_PARAM, + ) + # If there is a media item, check for existing restrictions + local_media_data = await self.store.get_local_media(mxc_uri.media_id) + + # Non-existent media is to be handled by the download media endpoint + # so ignore it for now and allow proceeding. It will just not attach + # the media + if local_media_data is not None and local_media_data.restricted: + if not local_media_data.attachments: + if requester.user.to_string() != local_media_data.user_id: + # A different user created a room compared to who uploaded + # the media. Just like with the '/state/' and '/send/' + # endpoints, do not leak the metadata + raise SynapseError( + HTTPStatus.BAD_REQUEST, + f"The media requested for a room avatar is invalid as the media '{mxc_uri.media_id}' does not exist", + Codes.INVALID_PARAM, + ) + + else: + # This media is already attached. 
If it was a prior attempt to + # create a room, the atomic handling of the room creation means + # it will not be attached, so if this exists it succeeded + # somewhere else. + raise SynapseError( + HTTPStatus.BAD_REQUEST, + f"The media requested for a room avatar is invalid as the media '{mxc_uri.media_id}' does not exist", + Codes.INVALID_PARAM, + ) + # trusted private chats have the invited users marked as additional creators if ( room_version.msc4289_creator_power_enabled @@ -1281,12 +1336,6 @@ async def create_room( check_membership=False, ) - raw_initial_state = config.get("initial_state", []) - - initial_state = OrderedDict() - for val in raw_initial_state: - initial_state[(val["type"], val.get("state_key", ""))] = val["content"] - ( last_stream_id, last_sent_event_id, From 3cf58b60b63bc37752c040629da229daec64df3a Mon Sep 17 00:00:00 2001 From: Jason Little Date: Wed, 17 Sep 2025 11:27:52 -0500 Subject: [PATCH 27/35] chore: Add some debug logging so media that is not visible has a discoverable reason --- synapse/media/media_repository.py | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/synapse/media/media_repository.py b/synapse/media/media_repository.py index 5eb583d018..0b9c114eef 100644 --- a/synapse/media/media_repository.py +++ b/synapse/media/media_repository.py @@ -330,6 +330,12 @@ async def is_media_visible( return # It was restricted, but no attachments. 
Deny + logger.debug( + "Media ID ('%s') as requested by '%s' was restricted but had no " + "attachments", + media_info_object.media_id, + requesting_user.to_string(), + ) raise UnauthorizedRequestAPICallError( f"Media requested ('{media_info_object.media_id}') is restricted" ) @@ -475,6 +481,14 @@ async def is_media_visible( requesting_user.to_string(), [attached_profile_user_id.to_string()], ): + logger.debug( + "Media ID (%s) as requested by '%s' was restricted by " + "profile, but was not allowed(is " + "'limit_profile_requests_to_users_who_share_rooms' enabled?)", + media_info_object.media_id, + requesting_user.to_string(), + ) + raise UnauthorizedRequestAPICallError( f"Media requested ('{media_info_object.media_id}') is restricted" ) @@ -487,6 +501,13 @@ async def is_media_visible( return # It was a third unknown restriction, or otherwise did not pass inspection + logger.debug( + "Media ID (%s) as requested by '%s' was restricted, but was not " + "allowed(media_attachments=%s)", + media_info_object.media_id, + requesting_user.to_string(), + media_info_object.attachments, + ) raise UnauthorizedRequestAPICallError( f"Media requested ('{media_info_object.media_id}') is restricted" ) From bc9dc1e1f0ed67cbd52a4b4216e10052a86c9cc9 Mon Sep 17 00:00:00 2001 From: Jason Little Date: Mon, 22 Sep 2025 13:05:04 -0500 Subject: [PATCH 28/35] fix: reorganize profile avatar validation to account for remote media --- synapse/handlers/profile.py | 37 +++++++++++++++++++++++++++++-------- 1 file changed, 29 insertions(+), 8 deletions(-) diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py index f8ff2dc947..b1809155b2 100644 --- a/synapse/handlers/profile.py +++ b/synapse/handlers/profile.py @@ -301,15 +301,36 @@ async def validate_avatar_url(self, avatar_url: str, requester: Requester) -> No avatar_url = f"mxc://{avatar_url}" mxc_uri = MXCUri.from_str(avatar_url) - media_info = await self.hs.get_datastores().main.get_local_media( - mxc_uri.media_id - ) - if 
media_info is None or media_info.user_id != requester.user.to_string(): - raise SynapseError( - HTTPStatus.BAD_REQUEST, - f"The media attachment request is invalid as the media '{mxc_uri.media_id}' does not exist", - Codes.INVALID_PARAM, + media_info: Union[LocalMedia, RemoteMedia, None] + if self._is_mine_server_name(mxc_uri.server_name): + media_info = await self.hs.get_datastores().main.get_local_media( + mxc_uri.media_id + ) + if not media_info: + # Locally, if there is no data on this, it implies they are making up + # things. Or, maybe, that they forgot to do a `/copy` first. Let them + # know + raise SynapseError( + HTTPStatus.BAD_REQUEST, + f"The media attachment request is invalid as the media '{mxc_uri.media_id}' does not exist", + Codes.INVALID_PARAM, + ) + + elif media_info.user_id != requester.user.to_string(): + # Media doesn't belong to the requester. Nope + raise SynapseError( + HTTPStatus.BAD_REQUEST, + f"The media attachment request is invalid as the media '{mxc_uri.media_id}' does not exist", + Codes.INVALID_PARAM, + ) + else: + media_info = await self.hs.get_datastores().main.get_cached_remote_media( + mxc_uri.server_name, mxc_uri.media_id ) + if not media_info: + # There is no data on this, not much can be done until it comes along + return + if self.disable_unrestricted_media and not media_info.restricted: raise SynapseError( HTTPStatus.BAD_REQUEST, From e950d0c53e083820a2c9e04f5029ac486c0e00a5 Mon Sep 17 00:00:00 2001 From: Jason Little Date: Tue, 23 Sep 2025 10:29:09 -0500 Subject: [PATCH 29/35] fix(MSC3911): Correct where to find remote media at on the local filesystem when doing a copy --- synapse/media/media_repository.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/synapse/media/media_repository.py b/synapse/media/media_repository.py index f1c0021453..a657f6f9a1 100644 --- a/synapse/media/media_repository.py +++ b/synapse/media/media_repository.py @@ -869,7 +869,8 @@ async def copy_media( old_media_info = await
self.get_media_info(existing_mxc) if isinstance(old_media_info, RemoteMedia): file_info = FileInfo( - server_name=old_media_info.media_origin, file_id=old_media_info.media_id + server_name=old_media_info.media_origin, + file_id=old_media_info.filesystem_id, ) else: file_info = FileInfo(server_name=None, file_id=old_media_info.media_id) From 815ff8f066b295cb381ed2e1bd10ff133c549028 Mon Sep 17 00:00:00 2001 From: Jason Little Date: Mon, 20 Oct 2025 07:50:12 -0500 Subject: [PATCH 30/35] fix(MSC3911): Handle profile avatars that are unknown gracefully --- synapse/handlers/profile.py | 16 +++++- synapse/handlers/room_member.py | 83 +++++++++++++++++++++++++++---- synapse/media/media_repository.py | 8 ++- tests/rest/client/test_profile.py | 38 +++++++++++++- tests/rest/client/test_rooms.py | 51 +++++++++++++++++-- 5 files changed, 180 insertions(+), 16 deletions(-) diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py index b1809155b2..5febb3a552 100644 --- a/synapse/handlers/profile.py +++ b/synapse/handlers/profile.py @@ -327,8 +327,22 @@ async def validate_avatar_url(self, avatar_url: str, requester: Requester) -> No media_info = await self.hs.get_datastores().main.get_cached_remote_media( mxc_uri.server_name, mxc_uri.media_id ) + # When our local user has attempted to set a profile avatar to a remote + # piece of media, but the local server has not actually seen it, only + # complain if unrestricted media is disabled, otherwise just allow it. 
Later + # when/if the new profile data is propagated to each room's membership + # event, it will either be copied/passed along/dropped depending on the + # above circumstances if not media_info: - # There is no data on this, not much can be done until it comes along + if self.disable_unrestricted_media: + # The user should have done a COPY on this media previous to this + # attempt to set + raise SynapseError( + HTTPStatus.NOT_FOUND, + "Profile request to update avatar to remote media can not proceed, a /copy request should have happened first", + errcode=Codes.NOT_FOUND, + ) + # For backwards compatible behavior, treat the media as unrestricted return if self.disable_unrestricted_media and not media_info.restricted: diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index d7cfa4c05b..18c5b0437c 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -110,6 +110,9 @@ def __init__(self, hs: "HomeServer"): self.event_auth_handler = hs.get_event_auth_handler() self._worker_lock_handler = hs.get_worker_locks_handler() self.enable_restricted_media = hs.config.experimental.msc3911_enabled + self.allow_legacy_media = ( + not hs.config.experimental.msc3911_unrestricted_media_upload_disabled + ) self._membership_types_to_include_profile_data_in = { Membership.JOIN, @@ -863,21 +866,83 @@ async def update_membership_locked( logger.info("Failed to get profile information for %r: %s", target, e) if self.enable_restricted_media and not media_info_for_attachment: - # Other than membership - avatar_url = content.get("avatar_url") + # This code path should only be taken for memberships updating an + # avatar url + avatar_url = content.get(EventContentFields.MEMBERSHIP_AVATAR_URL) if avatar_url: # Something about the MediaRepository does not like being part of # the initialization code of the RoomMemberHandler, so just import # it on the spot instead. 
media_repo = self.hs.get_media_repository() - new_mxc_uri = await media_repo.copy_media( - MXCUri.from_str(avatar_url), requester.user, 20_000 - ) - media_object = await media_repo.get_media_info(new_mxc_uri) - assert isinstance(media_object, LocalMedia) - media_info_for_attachment = {media_object} - content[EventContentFields.MEMBERSHIP_AVATAR_URL] = str(new_mxc_uri) + try: + # Run a preflight that this media exists. The http replication + # call should not be a thundering herd of guaranteed http fails. + existing_mxc_uri = MXCUri.from_str(avatar_url) + # The actual info is irrelevant at this stage, just ignore. This + # will raise if the media does not exist + _ = await media_repo.get_media_info(existing_mxc_uri) + + new_mxc_uri = await media_repo.copy_media( + existing_mxc_uri, requester.user, 20_000 + ) + except SynapseError: + # The only kind of media copying that should fail is from a + # remote item that has not been seen and cached previously. This + # should already be guarded from in the only circumstances we + # could think of: setting a profile avatar. + # + # If legacy unrestricted media is allowed, this can still be + # triggered. If this happens, ignore it. This allows the avatar + # url to be faithfully set to the given url. All we can do is + # hope that it is not restricted when it finally shows up. + # Guard against other potentials escaping by raising if it + # should occur when unrestricted media begins to be disallowed. + if self.allow_legacy_media: + logger.debug( + "Ignoring media copy request; the media is unknown and " + "will not be treated as restricted" + ) + + else: + # Don't fail the new membership event just because the media + # was not found. 
Specifically, for: + # * Outgoing remote invites + # * Joins in the context of a new room + # Just remove it, and let the client deal with the lack of + # hint + if effective_membership_state == Membership.INVITE or ( + effective_membership_state == Membership.JOIN + and new_room + ): + logger.warning( + "Unknown media (%s) can not be restricted to " + "membership event for '%s', dropping avatar from " + "event", + avatar_url, + target, + ) + del content[EventContentFields.MEMBERSHIP_AVATAR_URL] + + else: + # All other types of membership should raise. For + # whatever reason, the media is missing so this should + # not continue + logger.warning( + "Unknown media (%s) can not be restricted to " + "membership event for '%s'", + avatar_url, + target, + ) + raise + + else: + media_object = await media_repo.get_media_info(new_mxc_uri) + assert isinstance(media_object, LocalMedia) + media_info_for_attachment = {media_object} + content[EventContentFields.MEMBERSHIP_AVATAR_URL] = str( + new_mxc_uri + ) # if this is a join with a 3pid signature, we may need to turn a 3pid # invite into a normal invite before we can handle the join. 
diff --git a/synapse/media/media_repository.py b/synapse/media/media_repository.py index a657f6f9a1..3acfd48ca3 100644 --- a/synapse/media/media_repository.py +++ b/synapse/media/media_repository.py @@ -153,9 +153,13 @@ async def get_media_info(self, mxc_uri: MXCUri) -> Union[LocalMedia, RemoteMedia mxc_uri.server_name, mxc_uri.media_id ) if not media_info: - raise SynapseError(404, "Media not found", errcode="M_NOT_FOUND") + raise SynapseError( + HTTPStatus.NOT_FOUND, "Media not found", errcode=Codes.NOT_FOUND + ) if media_info.quarantined_by: - raise SynapseError(404, "Media not found", errcode="M_NOT_FOUND") + raise SynapseError( + HTTPStatus.NOT_FOUND, "Media not found", errcode=Codes.NOT_FOUND + ) return media_info async def create_or_update_content( diff --git a/tests/rest/client/test_profile.py b/tests/rest/client/test_profile.py index 110d6bdd0f..7bb618bc06 100644 --- a/tests/rest/client/test_profile.py +++ b/tests/rest/client/test_profile.py @@ -1015,7 +1015,7 @@ def test_can_attach_media_to_profile_update(self) -> None: assert restrictions.event_id is None assert restrictions.profile_user_id == UserID.from_string(self.user) - def test_attaching_nonexistent_media_to_profile_fails(self) -> None: + def test_attaching_nonexistent_local_media_to_profile_fails(self) -> None: """ Test that media that does not exist is not allowed to be attached to a user profile. """ @@ -1032,6 +1032,42 @@ def test_attaching_nonexistent_media_to_profile_fails(self) -> None: assert channel.json_body["errcode"] == Codes.INVALID_PARAM assert "does not exist" in channel.json_body["error"] + def test_attaching_unreachable_remote_media_to_profile_might_succeed(self) -> None: + """ + Test that media that can not be retrieved can be attached to a user profile, if + legacy unrestricted media is allowed. + """ + # Generate non-existing media. 
+ nonexistent_mxc_uri = MXCUri.from_str("mxc://remote/fakeMediaId") + channel = self.make_request( + "PUT", + f"/_matrix/client/v3/profile/{self.user}/avatar_url", + access_token=self.tok, + content={"avatar_url": str(nonexistent_mxc_uri)}, + ) + + assert channel.code == HTTPStatus.OK, channel.json_body + + @override_config( + {"experimental_features": {"msc3911_unrestricted_media_upload_disabled": True}} + ) + def test_attaching_unreachable_remote_media_to_profile_fails(self) -> None: + """ + Test that media that can not be retrieved will fail to be attached to a user + profile when legacy unrestricted media is disabled. + """ + # Generate non-existing media. + nonexistent_mxc_uri = MXCUri.from_str("mxc://remote/fakeMediaId_2") + channel = self.make_request( + "PUT", + f"/_matrix/client/v3/profile/{self.user}/avatar_url", + access_token=self.tok, + content={"avatar_url": str(nonexistent_mxc_uri)}, + ) + + assert channel.code == HTTPStatus.NOT_FOUND, channel.json_body + assert channel.json_body["errcode"] == Codes.NOT_FOUND + def test_attaching_unrestricted_media_to_profile(self) -> None: """ Test that attaching unrestricted media to user profile also works diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py index da1afbc94c..3def0eec06 100644 --- a/tests/rest/client/test_rooms.py +++ b/tests/rest/client/test_rooms.py @@ -5269,6 +5269,53 @@ def test_create_room_fails_with_malformed_room_avatar_url(self) -> None: room_id = self.create_room_with_avatar(avatar_mxc="junk", expected_code=400) assert room_id is None + @override_config( + {"experimental_features": {"msc3911_unrestricted_media_upload_disabled": True}} + ) + def test_create_room_with_missing_profile_avatar_media_succeeds(self) -> None: + """ + Test that a profile avatar that should automatically be included in a room + creator's join event does not break the room when the actual media of the avatar + is missing. 
+ """ + # First inject a profile avatar url directly into the database. The handler + # functions for such can not be used as they do validation, and it would fail as + # the media does not actually exist. + avatar_mxc_uri = MXCUri.from_str("mxc://fake-domain/whatever") + # Make sure to add the restrictions too + self.get_success_or_raise( + self.hs.get_datastores().main.set_media_restricted_to_user_profile( + avatar_mxc_uri.server_name, + avatar_mxc_uri.media_id, + self.user, + ) + ) + self.get_success_or_raise( + self.store.set_profile_avatar_url( + UserID.from_string(self.user), str(avatar_mxc_uri) + ) + ) + + # try and create room. This should succeed, but the avatar will have been + # stripped from the join event of the creator + room_id = self.helper.create_room_as( + self.user, + tok=self.tok, + ) + assert room_id is not None + # Make sure the avatar is not on the event + membership_as_set = self.get_success_or_raise( + self.store.get_membership_event_ids_for_user(self.user, room_id) + ) + join_event_id = membership_as_set.pop() + join_event = self.get_success_or_raise(self.store.get_event(join_event_id)) + assert join_event.content["membership"] == Membership.JOIN + assert "avatar_url" not in join_event.content + + # The display name would have been added, see if that is still there + assert "displayname" in join_event.content + assert join_event.content["displayname"] == "david" + class RoomMemberEventMediaAttachmentTestCase(unittest.HomeserverTestCase): servlets = [ @@ -5489,8 +5536,6 @@ def test_invite_with_media_get_copied_and_attached_to_event(self) -> None: assert channel.code == 200, channel.result["body"] # Get the member event of the invite just occurred. 
- # creator_event_ids = self.get_success(self.store.get_membership_event_ids_for_user(self.user, room_id)) - # assert len(creator_event_ids) == 1 invitee_event_ids = self.get_success( self.store.get_membership_event_ids_for_user(self.other_user, room_id) ) @@ -5502,7 +5547,7 @@ def test_invite_with_media_get_copied_and_attached_to_event(self) -> None: assert event.type == EventTypes.Member # Verify that this event a different mxc - assert event.content.get(EventContentFields.MEMBERSHIP_DISPLAYNAME) != str( + assert event.content.get(EventContentFields.MEMBERSHIP_AVATAR_URL) != str( invitee_avatar_mxc_uri ) From ca89fbf83d7d9b2b94c3c7bfee4a513e3cd2d961 Mon Sep 17 00:00:00 2001 From: Jason Little Date: Tue, 16 Sep 2025 17:32:38 -0500 Subject: [PATCH 31/35] Setup Complement testing infrastructure in preparation of media workers needing additional configuration --- .ci/scripts/checkout_complement.sh | 2 +- .../conf/workers-shared-extra.yaml.j2 | 2 ++ docker/configure_workers_and_start.py | 18 ++++++++++++++++-- scripts-dev/complement.sh | 1 + 4 files changed, 20 insertions(+), 3 deletions(-) diff --git a/.ci/scripts/checkout_complement.sh b/.ci/scripts/checkout_complement.sh index 379f5d4387..7730d8a4b1 100755 --- a/.ci/scripts/checkout_complement.sh +++ b/.ci/scripts/checkout_complement.sh @@ -21,5 +21,5 @@ for BRANCH_NAME in "$GITHUB_HEAD_REF" "$GITHUB_BASE_REF" "${GITHUB_REF#refs/head continue fi - (wget -O - "https://github.com/matrix-org/complement/archive/$BRANCH_NAME.tar.gz" | tar -xz --strip-components=1 -C complement) && break + (wget -O - "https://github.com/famedly/complement/archive/$BRANCH_NAME.tar.gz" | tar -xz --strip-components=1 -C complement) && break done diff --git a/docker/complement/conf/workers-shared-extra.yaml.j2 b/docker/complement/conf/workers-shared-extra.yaml.j2 index 48b44ddf90..b55eed4040 100644 --- a/docker/complement/conf/workers-shared-extra.yaml.j2 +++ b/docker/complement/conf/workers-shared-extra.yaml.j2 @@ -129,6 +129,8 @@ 
experimental_features: msc3984_appservice_key_query: true # Invite filtering msc4155_enabled: true + # Media Attachment + msc3911_enabled: true server_notices: system_mxid_localpart: _server diff --git a/docker/configure_workers_and_start.py b/docker/configure_workers_and_start.py index 6f25653bb7..fbe4dc14e2 100755 --- a/docker/configure_workers_and_start.py +++ b/docker/configure_workers_and_start.py @@ -117,7 +117,7 @@ }, "media_repository": { "app": "synapse.app.generic_worker", - "listener_resources": ["media", "client"], + "listener_resources": ["media", "client", "replication"], "endpoint_patterns": [ "^/_matrix/media/", "^/_synapse/admin/v1/purge_media_cache$", @@ -125,7 +125,7 @@ "^/_synapse/admin/v1/user/.*/media.*$", "^/_synapse/admin/v1/media/.*$", "^/_synapse/admin/v1/quarantine_media/.*$", - "^/_matrix/client/v1/media/.*$", + "^/_matrix/client/(v1|unstable/.*)/media/.*$", "^/_matrix/federation/v1/media/.*$", ], # The first configured media worker will run the media background jobs @@ -448,6 +448,9 @@ def add_worker_roles_to_shared_config( if "federation_sender" in worker_types_set: shared_config.setdefault("federation_sender_instances", []).append(worker_name) + if "media_repository" in worker_types_set: + shared_config.setdefault("media_repo_instances", []).append(worker_name) + # Update the list of stream writers. It's convenient that the name of the worker # type is the same as the stream to write. Iterate over the whole list in case there # is more than one. 
@@ -468,6 +471,17 @@ def add_worker_roles_to_shared_config( "host": "localhost", "port": worker_port, } + if worker == "media_repository": + # Just like for stream_writers, media workers now need to be on the instance_map + if os.environ.get("SYNAPSE_USE_UNIX_SOCKET", False): + instance_map[worker_name] = { + "path": f"/run/worker.{worker_port}", + } + else: + instance_map[worker_name] = { + "host": "localhost", + "port": worker_port, + } def merge_worker_template_configs( diff --git a/scripts-dev/complement.sh b/scripts-dev/complement.sh index 08b500ecd6..07ad04423d 100755 --- a/scripts-dev/complement.sh +++ b/scripts-dev/complement.sh @@ -230,6 +230,7 @@ test_packages=( ./tests/msc3967 ./tests/msc4140 ./tests/msc4155 + ./tests/msc3911 ) # Enable dirty runs, so tests will reuse the same container where possible. From f65fb91b4e32c5f9a5858da2d8dc536f877ebc6c Mon Sep 17 00:00:00 2001 From: Jason Little Date: Mon, 3 Nov 2025 07:33:56 -0600 Subject: [PATCH 32/35] Pin installed versions of python to 3.13 for linting and format checking github actions Somehow, PyO3 version 0.24.1 is being used which only has support for Python <= 3.13 even though the version pinned in Cargo.toml is 0.25.1. 
It is unknown where exactly this version of PyO3 is coming from, but caching oddities are suspected. Revert this after Synapse is overall bumped to Python 3.14 in v1.142.0 --- .github/workflows/famedly-tests.yml | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/.github/workflows/famedly-tests.yml b/.github/workflows/famedly-tests.yml index a328a6f4f6..fac74cdd6d 100644 --- a/.github/workflows/famedly-tests.yml +++ b/.github/workflows/famedly-tests.yml @@ -25,7 +25,7 @@ jobs: - uses: Swatinem/rust-cache@68b3cb7503c78e67dae8373749990a220eb65352 - uses: matrix-org/setup-python-poetry@v2 with: - python-version: "3.x" + python-version: "3.13" poetry-version: "2.1.1" extras: "all" - run: poetry run scripts-dev/generate_sample_config.sh --check @@ -49,7 +49,7 @@ jobs: - uses: actions/checkout@v4 - uses: actions/setup-python@v5 with: - python-version: "3.x" + python-version: "3.13" - run: .ci/scripts/check_lockfile.py lint: @@ -63,6 +63,7 @@ jobs: uses: matrix-org/setup-python-poetry@v2 with: poetry-version: "2.1.1" + python-version: "3.13" install-project: "false" - name: Run ruff check @@ -91,6 +92,7 @@ jobs: # https://github.com/matrix-org/synapse/pull/15376#issuecomment-1498983775 # To make CI green, err towards caution and install the project.
install-project: "true" + python-version: "3.13" poetry-version: "2.1.1" # Cribbed from @@ -124,6 +126,7 @@ jobs: - uses: matrix-org/setup-python-poetry@v2 with: poetry-version: "2.1.1" + python-version: "3.13" extras: "all" - run: poetry run scripts-dev/check_pydantic_models.py @@ -161,7 +164,7 @@ jobs: - uses: actions/checkout@v4 - uses: actions/setup-python@v5 with: - python-version: "3.x" + python-version: "3.13" - run: "pip install rstcheck" - run: "rstcheck --report-level=WARNING README.rst" From 3baaa1eaeb1a0fef089ea02aa9fadfc8855b0bcf Mon Sep 17 00:00:00 2001 From: Jason Little Date: Mon, 27 Oct 2025 07:11:25 -0500 Subject: [PATCH 33/35] chore(msc3911): Some test cleanups --- tests/federation/test_federation_media.py | 142 ++++++++++++++++- tests/replication/test_multi_media_repo.py | 5 - tests/rest/client/test_media.py | 173 ++++----------------- tests/rest/client/test_media_download.py | 22 +-- tests/rest/client/test_media_thumbnail.py | 22 +-- tests/rest/client/test_profile.py | 13 +- tests/storage/test_media.py | 143 +++-------------- 7 files changed, 214 insertions(+), 306 deletions(-) diff --git a/tests/federation/test_federation_media.py b/tests/federation/test_federation_media.py index a3e48d89fc..b453b83115 100644 --- a/tests/federation/test_federation_media.py +++ b/tests/federation/test_federation_media.py @@ -22,6 +22,8 @@ import os import shutil import tempfile +from typing import Dict, Optional +from unittest.mock import AsyncMock, Mock from twisted.test.proto_helpers import MemoryReactor @@ -31,11 +33,14 @@ FileStorageProviderBackend, StorageProviderWrapper, ) -from synapse.rest.client import login +from synapse.rest import admin +from synapse.rest.client import login, media from synapse.server import HomeServer from synapse.storage.database import LoggingTransaction +from synapse.storage.databases.main.media_repository import MediaRestrictions from synapse.types import JsonDict, UserID from synapse.util import Clock, json_encoder +from 
synapse.util.stringutils import random_string from tests import unittest from tests.media.test_media_storage import small_png @@ -191,9 +196,14 @@ def test_federation_etag(self) -> None: class FederationRestrictedMediaDownloadsTest(unittest.FederatingHomeserverTestCase): - servlets = [ - login.register_servlets, - ] + """ + Test that answering a federation download media request behaves appropriately + + More specifically, test that: + * downloads are achieved if restrictions are set + * downloads are blocked if restrictions are not set + * downloads are blocked if restrictions are malformed + """ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: super().prepare(reactor, clock, hs) @@ -226,6 +236,10 @@ def default_config(self) -> JsonDict: return config def test_restricted_media_download_with_restrictions_field(self) -> None: + """ + Test that a federation download media request can succeed and is shaped as + expected. + """ content = io.BytesIO(SMALL_PNG) content_uri = self.get_success( self.media_repo.create_or_update_content( @@ -281,6 +295,13 @@ def test_restricted_media_download_with_restrictions_field(self) -> None: self.assertTrue(found_file) def test_restricted_media_download_without_restrictions_field_fails(self) -> None: + """ + Test that restricted media with no restrictions defined is denied over federation + """ + # More specifically, restricted is marked True in the database, but the + # associated table of attachments has no entries. 
Do not confuse this with the + # potential of restricted being True, but the restrictions being defined but + # empty(as `{}`) content = io.BytesIO(SMALL_PNG) content_uri = self.get_success( self.media_repo.create_or_update_content( @@ -542,3 +563,116 @@ def test_restricted_thumbnail_download_with_restrictions_field(self) -> None: # Check that the png file exists and matches the expected scaled bytes found_file = any(small_png.expected_scaled in field for field in stripped_bytes) self.assertTrue(found_file) + + +class FederationClientDownloadTestCase(unittest.HomeserverTestCase): + """ + Test that an outgoing remote request for federation media is correctly parsed and + inserted into the local database + """ + + test_image = small_png + headers = { + b"Content-Length": [b"%d" % (len(test_image.data))], + b"Content-Type": [test_image.content_type], + b"Content-Disposition": [b"inline"], + } + + servlets = [ + media.register_servlets, + login.register_servlets, + admin.register_servlets, + ] + + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: + # Mock out the homeserver's MatrixFederationHttpClient + client = Mock() + federation_get_file = AsyncMock() + client.federation_get_file = federation_get_file + self.fed_client_mock = federation_get_file + + hs = self.setup_test_homeserver(federation_http_client=client) + + return hs + + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + self.store = hs.get_datastores().main + self.media_repo = hs.get_media_repository() + + self.remote_server = "example.com" + # mapping of media_id -> byte string of the json with the restrictions + self.media_id_data: Dict[str, bytes] = {} + + self.user = self.register_user("user", "pass") + self.tok = self.login("user", "pass") + + def generate_remote_media_id_and_restrictions( + self, json_dict_response: Optional[JsonDict] = None + ) -> str: + """ + Create a mocked remote media to retrieve + + Args: + json_dict_response: The 
attachments object that is included in the multipart + response received from federation + """ + media_id = random_string(24) + byte_string = b"{}" + if json_dict_response: + byte_string = json_encoder.encode(json_dict_response).encode() + + self.media_id_data[media_id] = byte_string + return media_id + + def make_request_for_media( + self, json_dict_response: Optional[JsonDict] = None, expected_code: int = 200 + ) -> str: + """ + Place a request to a (mocked) remote server. The request being placed is + actually to the local server, but redirects to the remote to retrieve the media. + This should insert the json part of the response automatically into the database + for us + """ + # Generate media id and restrictions based on SMALL_PNG + media_id = self.generate_remote_media_id_and_restrictions(json_dict_response) + self.fed_client_mock.return_value = ( + 67, + self.headers, + self.media_id_data[media_id], + ) + + channel = self.make_request( + "GET", + f"/_matrix/client/v1/media/download/{self.remote_server}/{media_id}", + shorthand=False, + access_token=self.tok, + ) + + self.assertEqual(channel.code, expected_code) + + return media_id + + def test_downloading_remote_media_with_restrictions_is_in_database(self) -> None: + """ + Test that remote media with restrictions correctly is inserted to the database + """ + # Note the unstable prefix is filtered out properly before persistence + media_id = self.make_request_for_media( + {"org.matrix.msc3911.restrictions": {"profile_user_id": "@bob:example.com"}} + ) + restrictions = self.get_success( + self.store.get_media_restrictions(self.remote_server, media_id) + ) + assert isinstance(restrictions, MediaRestrictions) + assert restrictions.profile_user_id is not None + assert restrictions.profile_user_id.to_string() == "@bob:example.com" + + def test_downloading_remote_media_with_no_restrictions_does_not_save_to_db( + self, + ) -> None: + """Test that remote media with no restrictions correctly skips a database entry""" 
+ media_id = self.make_request_for_media() + restrictions = self.get_success( + self.store.get_media_restrictions(self.remote_server, media_id) + ) + assert restrictions is None diff --git a/tests/replication/test_multi_media_repo.py b/tests/replication/test_multi_media_repo.py index 45bb842520..6d44394eff 100644 --- a/tests/replication/test_multi_media_repo.py +++ b/tests/replication/test_multi_media_repo.py @@ -506,15 +506,11 @@ class CopyRestrictedResourceReplicationTestCase(BaseMultiWorkerStreamTestCase): """ servlets = [ - # media.register_servlets, login.register_servlets, admin.register_servlets, room.register_servlets, ] - # def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: - # return self.setup_test_homeserver(config=config) - def default_config(self) -> Dict[str, Any]: config = super().default_config() config.update( @@ -531,7 +527,6 @@ def default_config(self) -> Dict[str, Any]: return config def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - # self.media_repo = hs.get_media_repository() self.profile_handler = self.hs.get_profile_handler() self.user = self.register_user("user", "testpass") self.user_tok = self.login("user", "testpass") diff --git a/tests/rest/client/test_media.py b/tests/rest/client/test_media.py index 78141777b5..f3f5339d4f 100644 --- a/tests/rest/client/test_media.py +++ b/tests/rest/client/test_media.py @@ -28,7 +28,7 @@ from contextlib import nullcontext from http import HTTPStatus from typing import Any, BinaryIO, ClassVar, Dict, List, Optional, Sequence, Tuple, Type -from unittest.mock import AsyncMock, MagicMock, Mock, patch +from unittest.mock import MagicMock, Mock, patch from urllib import parse from urllib.parse import quote, urlencode @@ -68,7 +68,6 @@ from synapse.server import HomeServer from synapse.storage.databases.main.media_repository import ( LocalMedia, - MediaRestrictions, ) from synapse.types import JsonDict, UserID, create_requester from synapse.util 
import Clock, json_encoder @@ -3001,22 +3000,25 @@ class DisableUnrestrictedResourceTestCase(unittest.HomeserverTestCase): limited when `msc3911_unrestricted_media_upload_disabled` is configured to be True. """ - extra_config = { - "experimental_features": {"msc3911_unrestricted_media_upload_disabled": True} - } servlets = [ media.register_servlets, ] - def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: - config = self.default_config() - config.update(self.extra_config) - return self.setup_test_homeserver(config=config) + def default_config(self) -> JsonDict: + config = super().default_config() + config.setdefault("experimental_features", {}) + config["experimental_features"].update( + {"msc3911_unrestricted_media_upload_disabled": True} + ) + return config def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.media_repo = hs.get_media_repository_resource() def create_resource_dict(self) -> dict[str, Resource]: + # Important detail: while we are specifically testing these endpoints are + # logically disabled, if we do not load them then this would test the wrong + # thing. resources = super().create_resource_dict() resources["/_matrix/media"] = self.hs.get_media_repository_resource() return resources @@ -3061,19 +3063,17 @@ class RestrictedResourceUploadTestCase(unittest.HomeserverTestCase): configured to be True. 
""" - extra_config = { - "experimental_features": {"msc3911_enabled": True}, - } servlets = [ media.register_servlets, login.register_servlets, admin.register_servlets, ] - def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: - config = self.default_config() - config.update(self.extra_config) - return self.setup_test_homeserver(config=config) + def default_config(self) -> JsonDict: + config = super().default_config() + config.setdefault("experimental_features", {}) + config["experimental_features"].update({"msc3911_enabled": True}) + return config def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.media_repo = hs.get_media_repository_resource() @@ -3084,6 +3084,7 @@ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.other_user_tok = self.login("random_user", "testpass") def create_resource_dict(self) -> dict[str, Resource]: + # Need this for test_async_upload_restricted_resource() resources = super().create_resource_dict() resources["/_matrix/media"] = self.hs.get_media_repository_resource() return resources @@ -3103,12 +3104,12 @@ def test_create_restricted_resource(self) -> None: media_id = channel.json_body["content_uri"].split("/")[-1] # Check the `restricted` field is True. - media = self.get_success( + local_media = self.get_success( self.hs.get_datastores().main.get_local_media(media_id) ) - assert media is not None - self.assertEqual(media.media_id, media_id) - self.assertTrue(media.restricted) + assert local_media is not None + self.assertEqual(local_media.media_id, media_id) + self.assertTrue(local_media.restricted) def test_upload_restricted_resource(self) -> None: """ @@ -3127,12 +3128,12 @@ def test_upload_restricted_resource(self) -> None: media_id = channel.json_body["content_uri"].split("/")[-1] # Check the `restricted` field is True. 
- media = self.get_success( + local_media = self.get_success( self.hs.get_datastores().main.get_local_media(media_id) ) - assert media is not None - self.assertEqual(media.media_id, media_id) - self.assertTrue(media.restricted) + assert local_media is not None + self.assertEqual(local_media.media_id, media_id) + self.assertTrue(local_media.restricted) # The media is not attached to any event yet, only creator can see it. # The creator can download the restricted resource. @@ -3179,12 +3180,12 @@ def test_async_upload_restricted_resource(self) -> None: self.assertEqual(channel.code, 200) # Check the `restricted` field is True. - media = self.get_success( + local_media = self.get_success( self.hs.get_datastores().main.get_local_media(media_id) ) - assert media is not None - self.assertEqual(media.media_id, media_id) - self.assertTrue(media.restricted) + assert local_media is not None + self.assertEqual(local_media.media_id, media_id) + self.assertTrue(local_media.restricted) # Media is not attached to any event yet, only creator can see it. # The creator can download the restricted resource. @@ -3209,10 +3210,6 @@ class CopyRestrictedResourceTestCase(unittest.HomeserverTestCase): Tests copy API when `msc3911_enabled` is configured to be True. 
""" - extra_config = { - "experimental_features": {"msc3911_enabled": True}, - } - servlets = [ media.register_servlets, login.register_servlets, @@ -3220,10 +3217,11 @@ class CopyRestrictedResourceTestCase(unittest.HomeserverTestCase): room.register_servlets, ] - def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: - config = self.default_config() - config.update(self.extra_config) - return self.setup_test_homeserver(config=config) + def default_config(self) -> JsonDict: + config = super().default_config() + config.setdefault("experimental_features", {}) + config["experimental_features"].update({"msc3911_enabled": True}) + return config def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.media_repo = hs.get_media_repository() @@ -3233,11 +3231,6 @@ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.other_user = self.register_user("other", "testpass") self.other_user_tok = self.login("other", "testpass") - def create_resource_dict(self) -> dict[str, Resource]: - resources = super().create_resource_dict() - resources["/_matrix/media"] = self.hs.get_media_repository_resource() - return resources - def fetch_media( self, mxc_uri: MXCUri, @@ -3622,7 +3615,6 @@ class RestrictedMediaVisibilityTestCase(unittest.HomeserverTestCase): login.register_servlets, media.register_servlets, room.register_servlets, - room.register_deprecated_servlets, ] def default_config(self) -> JsonDict: @@ -3642,12 +3634,6 @@ def prepare( self.alice_user_id = self.register_user("alice", "password") self.alice_tok = self.login("alice", "password") - def create_resource_dict(self) -> Dict[str, Resource]: - resources = super().create_resource_dict() - # The old endpoints are not loaded with the register_servlets above - resources["/_matrix/media"] = self.hs.get_media_repository_resource() - return resources - def create_restricted_media(self, user: Optional[str] = None) -> MXCUri: mxc_uri = 
self.get_success( self.media_repo.create_or_update_content( @@ -4505,97 +4491,6 @@ def test_global_profile_is_not_visible_when_not_sharing_a_room_setting_is_enable self.assert_expected_result(profile_viewing_user_id, media_object, False) -class FederationClientDownloadTestCase(unittest.HomeserverTestCase): - test_image = small_png - headers = { - b"Content-Length": [b"%d" % (len(test_image.data))], - b"Content-Type": [test_image.content_type], - b"Content-Disposition": [b"inline"], - } - - servlets = [ - media.register_servlets, - login.register_servlets, - admin.register_servlets, - ] - - def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: - # Mock out the homeserver's MatrixFederationHttpClient - client = Mock() - federation_get_file = AsyncMock() - client.federation_get_file = federation_get_file - self.fed_client_mock = federation_get_file - - hs = self.setup_test_homeserver(federation_http_client=client) - - return hs - - def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - self.store = hs.get_datastores().main - self.media_repo = hs.get_media_repository() - - self.remote_server = "example.com" - # mapping of media_id -> byte string of the json with the restrictions - self.media_id_data: Dict[str, bytes] = {} - - self.user = self.register_user("user", "pass") - self.tok = self.login("user", "pass") - - def generate_remote_media_id_and_restrictions( - self, json_dict_response: Optional[JsonDict] = None - ) -> str: - media_id = random_string(24) - byte_string = b"{}" - if json_dict_response: - byte_string = json_encoder.encode(json_dict_response).encode() - - self.media_id_data[media_id] = byte_string - return media_id - - def make_request_for_media( - self, json_dict_response: Optional[JsonDict] = None, expected_code: int = 200 - ) -> str: - # Generate media id and restrictions based on SMALL_PNG - media_id = self.generate_remote_media_id_and_restrictions(json_dict_response) - 
self.fed_client_mock.return_value = ( - 67, - self.headers, - self.media_id_data[media_id], - ) - - channel = self.make_request( - "GET", - f"/_matrix/client/v1/media/download/{self.remote_server}/{media_id}", - shorthand=False, - access_token=self.tok, - ) - - self.assertEqual(channel.code, expected_code) - - return media_id - - def test_downloading_remote_media_with_restrictions_is_in_database(self) -> None: - # Note the unstable prefix is filtered out properly before persistence - media_id = self.make_request_for_media( - {"org.matrix.msc3911.restrictions": {"profile_user_id": "@bob:example.com"}} - ) - restrictions = self.get_success( - self.store.get_media_restrictions(self.remote_server, media_id) - ) - assert isinstance(restrictions, MediaRestrictions) - assert restrictions.profile_user_id is not None - assert restrictions.profile_user_id.to_string() == "@bob:example.com" - - def test_downloading_remote_media_with_no_restrictions_does_not_save_to_db( - self, - ) -> None: - media_id = self.make_request_for_media() - restrictions = self.get_success( - self.store.get_media_restrictions(self.remote_server, media_id) - ) - assert restrictions is None - - configs_2 = [ {"enable_restricted_media": True}, {"enable_restricted_media": False}, diff --git a/tests/rest/client/test_media_download.py b/tests/rest/client/test_media_download.py index ece4b32f2d..b755f6ded8 100644 --- a/tests/rest/client/test_media_download.py +++ b/tests/rest/client/test_media_download.py @@ -22,7 +22,6 @@ from matrix_common.types.mxc_uri import MXCUri from twisted.test.proto_helpers import MemoryReactor -from twisted.web.resource import Resource from synapse.api.constants import ( EventContentFields, @@ -74,11 +73,6 @@ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: ) self.other_profile_test_user_tok = self.login("profile_test_user", "testpass") - def create_resource_dict(self) -> dict[str, Resource]: - resources = super().create_resource_dict() - 
resources["/_matrix/media"] = self.hs.get_media_repository_resource() - return resources - def _create_restricted_media(self, user: str) -> MXCUri: mxc_uri = self.get_success( self.repo.create_or_update_content( @@ -109,7 +103,7 @@ def fetch_media( ) assert channel.code == expected_code, channel.code - def test_user_download_local_media_unrestricted(self) -> None: + def test_local_media_download_unrestricted(self) -> None: """Test that unrestricted media is not affected""" mxc_uri = self.get_success( self.repo.create_or_update_content( @@ -125,7 +119,7 @@ def test_user_download_local_media_unrestricted(self) -> None: self.fetch_media(mxc_uri) self.fetch_media(mxc_uri, access_token=self.other_user_tok) - def test_download_local_media_restricted_but_pending_state(self) -> None: + def test_local_media_download_restricted_but_pending_state(self) -> None: """Test originating user can access media even though it is not attached""" mxc_uri = self._create_restricted_media(self.creator) # The creator user can see their own media @@ -133,7 +127,7 @@ def test_download_local_media_restricted_but_pending_state(self) -> None: # But another user can not self.fetch_media(mxc_uri, access_token=self.other_user_tok, expected_code=403) - def test_user_download_local_media_attached_to_user_profile_success(self) -> None: + def test_local_media_download_attached_to_user_profile_success(self) -> None: """Test retrieving media attached to user's profile""" prime_mxc_uri = self._create_restricted_media(self.creator) other_mxc_uri = self._create_restricted_media(self.other_profile_test_user) @@ -166,7 +160,7 @@ def test_user_download_local_media_attached_to_user_profile_success(self) -> Non "limit_profile_requests_to_users_who_share_rooms": True, } ) - def test_user_download_local_media_attached_to_user_profile_failure(self) -> None: + def test_local_media_download_attached_to_user_profile_failure(self) -> None: """ Test that limiting profile requests works as expected. 
Specifically, that users that are not sharing a room can not see profile avatars @@ -204,7 +198,7 @@ def test_user_download_local_media_attached_to_user_profile_failure(self) -> Non expected_code=403, ) - def test_user_download_local_media_attached_to_message_event_success(self) -> None: + def test_local_media_download_attached_to_message_event_success(self) -> None: """Test that can local media attached to image event can be viewed""" mxc_uri = self._create_restricted_media(self.creator) room_id = self.helper.create_room_as(self.creator, tok=self.creator_tok) @@ -240,7 +234,7 @@ def test_user_download_local_media_attached_to_message_event_success(self) -> No self.fetch_media(mxc_uri) self.fetch_media(mxc_uri, access_token=self.other_user_tok) - def test_user_download_local_media_attached_to_message_event_failure(self) -> None: + def test_local_media_download_attached_to_message_event_failure(self) -> None: """Test that can local media attached to image event can be restricted""" mxc_uri = self._create_restricted_media(self.creator) room_id = self.helper.create_room_as(self.creator, tok=self.creator_tok) @@ -277,7 +271,7 @@ def test_user_download_local_media_attached_to_message_event_failure(self) -> No # should fail. 
self.fetch_media(mxc_uri, access_token=self.other_user_tok, expected_code=403) - def test_user_download_local_media_attached_to_state_event_success(self) -> None: + def test_local_media_download_attached_to_state_event_success(self) -> None: """Test that a simple membership avatar is viewable when appropriate""" mxc_uri = self._create_restricted_media(self.creator) room_id = self.helper.create_room_as(self.creator, tok=self.creator_tok) @@ -311,7 +305,7 @@ def test_user_download_local_media_attached_to_state_event_success(self) -> None self.fetch_media(mxc_uri) self.fetch_media(mxc_uri, access_token=self.other_user_tok) - def test_user_download_local_media_attached_to_state_event_failure(self) -> None: + def test_local_media_download_attached_to_state_event_failure(self) -> None: """Test that a simple membership avatar is restricted when appropriate""" mxc_uri = self._create_restricted_media(self.creator) room_id = self.helper.create_room_as(self.creator, tok=self.creator_tok) diff --git a/tests/rest/client/test_media_thumbnail.py b/tests/rest/client/test_media_thumbnail.py index 587babfeee..dd65c1e868 100644 --- a/tests/rest/client/test_media_thumbnail.py +++ b/tests/rest/client/test_media_thumbnail.py @@ -22,7 +22,6 @@ from matrix_common.types.mxc_uri import MXCUri from twisted.internet.testing import MemoryReactor -from twisted.web.resource import Resource from synapse.api.constants import ( EventContentFields, @@ -81,11 +80,6 @@ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: ) self.other_profile_test_user_tok = self.login("profile_test_user", "testpass") - def create_resource_dict(self) -> dict[str, Resource]: - resources = super().create_resource_dict() - resources["/_matrix/media"] = self.hs.get_media_repository_resource() - return resources - def _create_restricted_media(self, user: str) -> MXCUri: """ Insert our media directly into the database/repo. 
This creates the necessary @@ -123,7 +117,7 @@ def fetch_thumbnail( assert channel.code == expect_code, channel.code return channel - def test_user_download_local_media_thumbnail_unrestricted(self) -> None: + def test_local_media_thumbnail_unrestricted(self) -> None: """Test that unrestricted media is not affected""" # Note that 'restricted' is marked as 'False' here content_mxc_uri = self.get_success( @@ -140,7 +134,7 @@ def test_user_download_local_media_thumbnail_unrestricted(self) -> None: self.fetch_thumbnail(content_mxc_uri) self.fetch_thumbnail(content_mxc_uri, access_token=self.other_user_tok) - def test_download_local_media_restricted_but_pending_state(self) -> None: + def test_local_media_thumbnail_restricted_but_pending_state(self) -> None: """Test originating user can access media even though it is not attached""" mxc_uri = self._create_restricted_media(self.creator) @@ -149,7 +143,7 @@ def test_download_local_media_restricted_but_pending_state(self) -> None: # But another user can not self.fetch_thumbnail(mxc_uri, access_token=self.other_user_tok, expect_code=403) - def test_user_download_local_media_attached_to_user_profile_success(self) -> None: + def test_local_media_thumbnail_attached_to_user_profile_success(self) -> None: """Test retrieving media attached to user's profile""" prime_mxc_uri = self._create_restricted_media(self.creator) other_mxc_uri = self._create_restricted_media(self.other_profile_test_user) @@ -186,7 +180,7 @@ def test_user_download_local_media_attached_to_user_profile_success(self) -> Non "limit_profile_requests_to_users_who_share_rooms": True, } ) - def test_user_download_local_media_attached_to_user_profile_failure(self) -> None: + def test_local_media_thumbnail_attached_to_user_profile_failure(self) -> None: """ Test that limiting profile requests works as expected. 
Specifically, that users that are not sharing a room can not see profile avatars @@ -226,7 +220,7 @@ def test_user_download_local_media_attached_to_user_profile_failure(self) -> Non expect_code=403, ) - def test_users_download_local_media_attached_to_message_event_success(self) -> None: + def test_local_media_thumbnail_attached_to_message_event_success(self) -> None: """Test that can local media attached to image event can be viewed""" mxc_uri = self._create_restricted_media(self.creator) room_id = self.helper.create_room_as(self.creator, tok=self.creator_tok) @@ -262,7 +256,7 @@ def test_users_download_local_media_attached_to_message_event_success(self) -> N self.fetch_thumbnail(mxc_uri) self.fetch_thumbnail(mxc_uri, access_token=self.other_user_tok) - def test_users_download_local_media_attached_to_message_event_failure(self) -> None: + def test_local_media_thumbnail_attached_to_message_event_failure(self) -> None: """Test that can local media attached to image event can be restricted""" mxc_uri = self._create_restricted_media(self.creator) room_id = self.helper.create_room_as(self.creator, tok=self.creator_tok) @@ -299,7 +293,7 @@ def test_users_download_local_media_attached_to_message_event_failure(self) -> N # should fail. 
self.fetch_thumbnail(mxc_uri, access_token=self.other_user_tok, expect_code=403) - def test_user_download_local_media_attached_to_state_event_success(self) -> None: + def test_local_media_thumbnail_attached_to_state_event_success(self) -> None: """Test that a simple membership avatar is viewable when appropriate""" mxc_uri = self._create_restricted_media(self.creator) room_id = self.helper.create_room_as(self.creator, tok=self.creator_tok) @@ -333,7 +327,7 @@ def test_user_download_local_media_attached_to_state_event_success(self) -> None self.fetch_thumbnail(mxc_uri) self.fetch_thumbnail(mxc_uri, access_token=self.other_user_tok) - def test_user_download_local_media_attached_to_state_event_failure(self) -> None: + def test_local_media_thumbnail_attached_to_state_event_failure(self) -> None: """Test that a simple membership avatar is restricted when appropriate""" mxc_uri = self._create_restricted_media(self.creator) room_id = self.helper.create_room_as(self.creator, tok=self.creator_tok) diff --git a/tests/rest/client/test_profile.py b/tests/rest/client/test_profile.py index 7bb618bc06..52a48c22f5 100644 --- a/tests/rest/client/test_profile.py +++ b/tests/rest/client/test_profile.py @@ -1103,7 +1103,7 @@ def test_attaching_unrestricted_media_to_profile(self) -> None: ) def test_attaching_unrestricted_media_to_profile_fails(self) -> None: """ - Test that attaching unrestricted media to user profile fails when unrestircted + Test that attaching unrestricted media to user profile fails when unrestricted media is banned by configuration. """ # Create unrestricted media. 
@@ -1308,7 +1308,6 @@ class ProfileMediaAttachmentReplicationTestCase(BaseMultiWorkerStreamTestCase): admin.register_servlets, login.register_servlets, media.register_servlets, - # profile.register_servlets, room.register_servlets, ] @@ -1329,15 +1328,8 @@ def default_config(self) -> JsonDict: config = super().default_config() config.setdefault("experimental_features", {}) config["experimental_features"].update({"msc3911_enabled": True}) - # config["media_repo_instances"] = [MAIN_PROCESS_INSTANCE_NAME] - return config - def create_resource_dict(self) -> dict[str, Resource]: - resources = super().create_resource_dict() - resources["/_matrix/media"] = self.hs.get_media_repository_resource() - return resources - def create_media_and_set_restricted_flag(self, user_id: str) -> MXCUri: """ Create media without using an endpoint, and set the restricted flag. @@ -1360,9 +1352,6 @@ def make_worker_hs( ) -> HomeServer: worker_hs = super().make_worker_hs(worker_app, extra_config, **kwargs) # Mount the room resource onto the worker. 
- # worker_hs.get_media_repository_resource().register_servlets( - # self._hs_to_site[worker_hs].resource, worker_hs - # ) room.register_servlets(worker_hs, self._hs_to_site[worker_hs].resource) profile.register_servlets(worker_hs, self._hs_to_site[worker_hs].resource) return worker_hs diff --git a/tests/storage/test_media.py b/tests/storage/test_media.py index d6b9219e3a..6bac81c1a6 100644 --- a/tests/storage/test_media.py +++ b/tests/storage/test_media.py @@ -1,26 +1,22 @@ -import io -from typing import Dict - -from matrix_common.types.mxc_uri import MXCUri - from twisted.test.proto_helpers import MemoryReactor -from twisted.web.resource import Resource from synapse.api.errors import SynapseError -from synapse.rest import admin -from synapse.rest.client import login, media from synapse.server import HomeServer -from synapse.storage.databases.main.media_repository import MediaRestrictions -from synapse.types import JsonDict, UserID +from synapse.types import UserID from synapse.util import Clock from synapse.util.stringutils import random_string from tests import unittest -from tests.test_utils import SMALL_PNG class MediaAttachmentStorageTestCase(unittest.HomeserverTestCase): - """Test that storing and retrieving media restrictions works as expected""" + """ + Test that storing and retrieving media restrictions works as expected + + Specifically, we test that storing media restrictions are then retrievable, that + our MediaRestrictions object is created as expected, and that a given piece of media + can not be set twice(no overwriting of values) + """ def prepare( self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer @@ -61,6 +57,7 @@ def test_store_and_retrieve_media_restrictions_by_profile_user_id(self) -> None: assert retrieved_restrictions.profile_user_id == user_id def test_retrieve_media_without_restrictions(self) -> None: + """Test that retrieving non-existent restrictions does not raise an exception""" media_id = random_string(24) 
retrieved_restrictions = self.get_success_or_raise( @@ -68,45 +65,13 @@ def test_retrieve_media_without_restrictions(self) -> None: ) assert retrieved_restrictions is None - -class MediaPendingAttachmentTestCase(unittest.HomeserverTestCase): - servlets = [ - admin.register_servlets, - login.register_servlets, - media.register_servlets, - ] - - def default_config(self) -> JsonDict: - config = super().default_config() - return config - - def prepare( - self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer - ) -> None: - self.store = homeserver.get_datastores().main - self.server_name = self.hs.config.server.server_name - - self.user = self.register_user("frank", "password") - self.tok = self.login("frank", "password") - - def create_resource_dict(self) -> Dict[str, Resource]: - resources = super().create_resource_dict() - # The old endpoints are not loaded with the register_servlets above - resources["/_matrix/media"] = self.hs.get_media_repository_resource() - return resources - def test_setting_media_restriction_twice_errors( self, ) -> None: """Setting media restrictions on a single piece of media TWICE is not allowed. 
Test that it errors """ - upload_result = self.helper.upload_media(SMALL_PNG, tok=self.tok) - assert upload_result.get("content_uri") is not None - - content_uri: str = upload_result["content_uri"] - # We can split the content_uri on the last "/" and the rest is the media_id - media_id = content_uri.rsplit("/", maxsplit=1)[1] + media_id = random_string(24) event_id = "$something_hashy_doesnt_matter" self.get_success( @@ -122,87 +87,29 @@ def test_setting_media_restriction_twice_errors( ) ) assert existing_media_restrictions is not None + assert existing_media_restrictions.profile_user_id is None + assert existing_media_restrictions.event_id == event_id + new_event_id = "$something_newer_but_still_hashy" self.get_failure( self.store.set_media_restricted_to_event_id( - self.server_name, media_id, event_id + self.server_name, media_id, new_event_id ), SynapseError, ) - -class MediaAttachmentFlowTestCase(unittest.HomeserverTestCase): - servlets = [ - admin.register_servlets, - login.register_servlets, - media.register_servlets, - ] - - def prepare( - self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer - ) -> None: - self.store = homeserver.get_datastores().main - self.server_name = self.hs.config.server.server_name - self.media_repo = self.hs.get_media_repository() - - self.user = self.register_user("frank", "password") - self.tok = self.login("frank", "password") - - def create_resource_dict(self) -> Dict[str, Resource]: - resources = super().create_resource_dict() - # The old endpoints are not loaded with the register_servlets above - resources["/_matrix/media"] = self.hs.get_media_repository_resource() - return resources - - def create_media(self) -> MXCUri: - content = io.BytesIO(SMALL_PNG) - content_uri = self.get_success( - self.media_repo.create_or_update_content( - "image/png", - "test_png_upload", - content, - 67, - UserID.from_string("@user_id:whatever.org"), - restricted=True, - ) - ) - return content_uri - - def test_flow(self) -> None: - 
"""Example flow of storing media data and retrieving it from the database""" - # Create media by using create_or_update_content() helper. This will likely be - # on the new `/create` and `/upload` endpoints for msc3911. - - # set actual restrictions using storage methods - # `set_media_restricted_to_event_id()` or `set_media_restricted_to_user_profile()` - - # use `get_local_media()` to retrieve the data - - mxc_uri = self.create_media() - media_id = mxc_uri.media_id - assert media_id - - local_media_object = self.get_success(self.store.get_local_media(media_id)) - assert local_media_object - assert local_media_object.restricted is True - - # This one is why we are here, it doesn't exist yet - assert local_media_object.attachments is None - - event_id = "$event_id_hash_goes_here" - self.get_success( - self.store.set_media_restricted_to_event_id( + # Verify that even with the error, nothing has actually changed + verify_media_restrictions = self.get_success( + self.store.get_media_restrictions( self.server_name, media_id, - event_id, ) ) - - # Retrieve the data and make sure the restrictions are there - local_media_object = self.get_success(self.store.get_local_media(media_id)) - assert local_media_object - - assert local_media_object.restricted is True - # This one is why we are here, it's here this time. Yay! 
- assert isinstance(local_media_object.attachments, MediaRestrictions) - assert local_media_object.attachments.event_id == event_id + assert verify_media_restrictions is not None + assert ( + verify_media_restrictions.profile_user_id + == existing_media_restrictions.profile_user_id + ) + assert ( + verify_media_restrictions.event_id == existing_media_restrictions.event_id + ) From 3cd87fbff83c747534d9b140c62c51d3f973ada5 Mon Sep 17 00:00:00 2001 From: Soyoung Kim Date: Thu, 18 Sep 2025 11:33:33 +0200 Subject: [PATCH 34/35] feat: automatic pending media deletion --- synapse/config/experimental.py | 7 +- synapse/media/media_repository.py | 11 + .../databases/main/media_repository.py | 23 ++ tests/media/test_pending_media_deletion.py | 239 ++++++++++++++++++ 4 files changed, 279 insertions(+), 1 deletion(-) create mode 100644 tests/media/test_pending_media_deletion.py diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index 6af05960f5..c5f54995f3 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -592,7 +592,12 @@ def read_config( # MSC3911: Linking Media to Events self.msc3911_enabled: bool = experimental.get("msc3911_enabled", False) - # Disable the current media create and upload endpoints + # MSC3911: Disable the current media create and upload endpoints self.msc3911_unrestricted_media_upload_disabled: bool = experimental.get( "msc3911_unrestricted_media_upload_disabled", False ) + + # MSC3911: Delete pending media that is older than 24 hours but not attached to any events + self.msc3911_enabled_media_retention = experimental.get( + "msc3911_enabled_media_retention", False + ) diff --git a/synapse/media/media_repository.py b/synapse/media/media_repository.py index 3acfd48ca3..455c429e64 100644 --- a/synapse/media/media_repository.py +++ b/synapse/media/media_repository.py @@ -101,6 +101,7 @@ # How often to run the background job to check for local and remote media # that should be purged according to the 
configured media retention settings. MEDIA_RETENTION_CHECK_PERIOD_MS = 60 * 60 * 1000 # 1 hour +PENDING_MEDIA_CLEANUP_INTERVAL_MS = 60 * 60 * 1000 # 1 hour class AbstractMediaRepository: @@ -643,6 +644,11 @@ def __init__(self, hs: "HomeServer"): key=lambda limit: limit.time_period_ms, reverse=True ) + if self.hs.config.experimental.msc3911_enabled_media_retention: + self.clock.looping_call( + self._pending_media_cleanup, PENDING_MEDIA_CLEANUP_INTERVAL_MS + ) + def _start_update_recently_accessed(self) -> Deferred: return run_as_background_process( "update_recently_accessed_media", self._update_recently_accessed @@ -664,6 +670,11 @@ async def _update_recently_accessed(self) -> None: local_media, remote_media, self.clock.time_msec() ) + async def _pending_media_cleanup(self) -> None: + pending_media_ids = await self.store.get_pending_media_ids() + if pending_media_ids: + await self.delete_local_media_ids(pending_media_ids) + def mark_recently_accessed(self, server_name: Optional[str], media_id: str) -> None: """Mark the given media as recently accessed. diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py index 18ad97959e..9b8cbe22cd 100644 --- a/synapse/storage/databases/main/media_repository.py +++ b/synapse/storage/databases/main/media_repository.py @@ -638,6 +638,29 @@ def get_pending_media_txn(txn: LoggingTransaction) -> Tuple[int, int]: "get_pending_media", get_pending_media_txn ) + async def get_pending_media_ids(self) -> list[str]: + """ + Get a list of ids of pending media that is older than 24 hours and unattached. 
+ """ + threshold_ts = self._clock.time_msec() - 24 * 60 * 60 * 1000 + + def _get_pending_media_ids_txn(txn: LoggingTransaction) -> list[str]: + sql = """ + SELECT local_media_repository.media_id + FROM local_media_repository + LEFT JOIN media_attachments + ON local_media_repository.media_id = media_attachments.media_id + WHERE local_media_repository.restricted IS TRUE + AND media_attachments.restrictions_json IS NULL + AND local_media_repository.created_ts < ?; + """ + txn.execute(sql, (threshold_ts,)) + return [row[0] for row in txn] + + return await self.db_pool.runInteraction( + "get_pending_media_ids", _get_pending_media_ids_txn + ) + async def get_url_cache(self, url: str, ts: int) -> Optional[UrlCache]: """Get the media_id and ts for a cached URL as of the given timestamp Returns: diff --git a/tests/media/test_pending_media_deletion.py b/tests/media/test_pending_media_deletion.py new file mode 100644 index 0000000000..fb816227a1 --- /dev/null +++ b/tests/media/test_pending_media_deletion.py @@ -0,0 +1,239 @@ +import os +from http import HTTPStatus + +from twisted.test.proto_helpers import MemoryReactor +from twisted.web.resource import Resource + +from synapse.media.filepath import MediaFilePaths +from synapse.media.media_repository import MediaRepository +from synapse.rest import admin +from synapse.rest.client import login, media, profile, room +from synapse.server import HomeServer +from synapse.types import UserID +from synapse.util import Clock +from synapse.util.stringutils import ( + random_string, +) + +from tests import unittest + + +class PendingMediaDeletionTestCase(unittest.HomeserverTestCase): + servlets = [ + media.register_servlets, + login.register_servlets, + admin.register_servlets, + room.register_servlets, + profile.register_servlets, + ] + + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: + config = self.default_config() + config.update( + { + "experimental_features": { + "msc3911_enabled": True, + 
"msc3911_enabled_media_retention": True, + }, + } + ) + return self.setup_test_homeserver(config=config) + + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + self.media_repository = hs.get_media_repository() + self.store = hs.get_datastores().main + + self.user = self.register_user("user", "testpass") + self.tok = self.login("user", "testpass") + + self.filepaths = MediaFilePaths(hs.config.media.media_store_path) + + def create_resource_dict(self) -> dict[str, Resource]: + resources = super().create_resource_dict() + resources["/_matrix/media"] = self.hs.get_media_repository_resource() + return resources + + def test_pending_media_deletion_success(self) -> None: + """ + Test that media that is older than 24 hours yet not attached to any event or profile is deleted. + """ + assert isinstance(self.media_repository, MediaRepository) + + # Create 2 media that is restricted but not attached to any event or profile + random_content = bytes(random_string(24), "utf-8") + channel = self.make_request( + "POST", + "_matrix/client/unstable/org.matrix.msc3911/media/upload?filename=test_1", + random_content, + self.tok, + shorthand=False, + content_type=b"image/png", + custom_headers=[("Content-Length", str(24))], + ) + assert channel.code == 200, channel.json_body + mxc_uri_str = channel.json_body.get("content_uri") + assert mxc_uri_str is not None + media_1_id = mxc_uri_str.rsplit("/", 1)[-1] + + random_content = bytes(random_string(24), "utf-8") + channel = self.make_request( + "POST", + "_matrix/client/unstable/org.matrix.msc3911/media/upload?filename=test_2", + random_content, + self.tok, + shorthand=False, + content_type=b"image/png", + custom_headers=[("Content-Length", str(24))], + ) + assert channel.code == 200, channel.json_body + mxc_uri_str = channel.json_body.get("content_uri") + assert mxc_uri_str is not None + media_2_id = mxc_uri_str.rsplit("/", 1)[-1] + + # Prove that the media are written on the local media table + 
uploaded_media = self.get_success( + self.media_repository.store.get_local_media(media_1_id) + ) + assert uploaded_media is not None + assert uploaded_media.attachments is None + + uploaded_media = self.get_success( + self.media_repository.store.get_local_media(media_2_id) + ) + assert uploaded_media is not None + assert uploaded_media.attachments is None + + # Check if the file exists + local_path_1 = self.filepaths.local_media_filepath(media_1_id) + assert os.path.exists(local_path_1) + local_path_2 = self.filepaths.local_media_filepath(media_2_id) + assert os.path.exists(local_path_2) + + # Advance 25 hours to make the media eligible for deletion + self.reactor.advance(25 * 60 * 60) + + # Check the deletion is completed + uploaded_media = self.get_success( + self.media_repository.store.get_local_media(media_1_id) + ) + assert uploaded_media is None + + uploaded_media = self.get_success( + self.media_repository.store.get_local_media(media_2_id) + ) + assert uploaded_media is None + + # Attempt to access media to check if media is deleted from file + server_name = self.hs.hostname + channel = self.make_request( + "GET", + f"/_matrix/media/v3/download/{server_name}/{media_1_id}", + shorthand=False, + access_token=self.tok, + ) + assert channel.code == 404, ( + "Expected to receive a 404 on accessing deleted media: %s:%s" + % (server_name, media_1_id) + ) + + channel = self.make_request( + "GET", + f"/_matrix/media/v3/download/{server_name}/{media_2_id}", + shorthand=False, + access_token=self.tok, + ) + assert channel.code == 404, ( + "Expected to receive a 404 on accessing deleted media: %s:%s" + % (server_name, media_1_id) + ) + + # Test if the file is deleted + assert os.path.exists(local_path_1) is False + assert os.path.exists(local_path_2) is False + + def test_pending_media_deletion_not_deleted_if_attached(self) -> None: + """ + Test that media that is attached to an event or profile is not deleted. 
+ """ + assert isinstance(self.media_repository, MediaRepository) + + # Create media via upload endpoint + random_content = bytes(random_string(24), "utf-8") + channel = self.make_request( + "POST", + "_matrix/client/unstable/org.matrix.msc3911/media/upload?filename=test_png_upload", + random_content, + self.tok, + shorthand=False, + content_type=b"image/png", + custom_headers=[("Content-Length", str(24))], + ) + assert channel.code == 200, channel.json_body + mxc_uri_str = channel.json_body.get("content_uri") + assert mxc_uri_str is not None + + # Attach the media to a profile + channel = self.make_request( + "PUT", + f"/_matrix/client/v3/profile/{self.user}/avatar_url", + access_token=self.tok, + content={"avatar_url": mxc_uri_str}, + ) + assert channel.code == HTTPStatus.OK + assert channel.json_body == {} + _, media_id = mxc_uri_str.rsplit("/", maxsplit=1) + + # Check if media is updated with restrictions field + restrictions = self.get_success( + self.store.get_media_restrictions(self.hs.hostname, media_id) + ) + assert restrictions is not None, str(restrictions) + assert restrictions.event_id is None + assert restrictions.profile_user_id == UserID.from_string(self.user) + + # Advance 25 hours + self.reactor.advance(25 * 60 * 60) + + # Check that media is not deleted + uploaded_media = self.get_success( + self.media_repository.store.get_local_media(media_id) + ) + assert uploaded_media is not None + + def test_pending_media_deletion_does_not_delete_unrestricted_media(self) -> None: + """ + Test that unrestricted media should not be deleted. 
+ """ + assert isinstance(self.media_repository, MediaRepository) + + # Create unrestricted media via upload endpoint + random_content = bytes(random_string(24), "utf-8") + channel = self.make_request( + "POST", + "_matrix/media/v3/upload?filename=unrestricted", + random_content, + self.tok, + shorthand=False, + content_type=b"image/png", + custom_headers=[("Content-Length", str(24))], + ) + assert channel.code == 200, channel.json_body + mxc_uri_str = channel.json_body.get("content_uri") + assert mxc_uri_str is not None + _, media_id = mxc_uri_str.rsplit("/", maxsplit=1) + + # Check if media is not restricted + media_info = self.get_success( + self.media_repository.store.get_local_media(media_id) + ) + assert media_info is not None + assert media_info.restricted is False + + # Advance 25 hours + self.reactor.advance(25 * 60 * 60) + + # Check that media is not deleted + uploaded_media = self.get_success( + self.media_repository.store.get_local_media(media_id) + ) + assert uploaded_media is not None From 595e5edfcade992e91c6e49bc413cc8819550a47 Mon Sep 17 00:00:00 2001 From: Soyoung Kim Date: Mon, 1 Dec 2025 16:56:15 +0100 Subject: [PATCH 35/35] chore: add msc3911 config class --- .../conf/workers-shared-extra.yaml.j2 | 6 +- synapse/config/experimental.py | 29 ++-- .../federation/transport/server/federation.py | 2 +- synapse/handlers/profile.py | 10 +- synapse/handlers/room.py | 4 +- synapse/handlers/room_member.py | 4 +- synapse/media/media_repository.py | 16 ++- synapse/media/thumbnailer.py | 2 +- synapse/rest/client/media.py | 2 +- synapse/rest/client/profile.py | 14 +- synapse/rest/client/room.py | 4 +- synapse/rest/client/versions.py | 4 +- synapse/rest/media/create_resource.py | 6 +- synapse/rest/media/upload_resource.py | 7 +- .../databases/main/media_repository.py | 7 +- tests/federation/test_federation_media.py | 4 +- tests/media/test_pending_media_deletion.py | 21 +-- tests/replication/test_multi_media_repo.py | 4 +- tests/rest/client/test_media.py | 127 
++++++++++++++++-- tests/rest/client/test_media_download.py | 2 +- tests/rest/client/test_media_thumbnail.py | 2 +- tests/rest/client/test_profile.py | 64 +-------- tests/rest/client/test_rooms.py | 55 +------- 23 files changed, 207 insertions(+), 189 deletions(-) diff --git a/docker/complement/conf/workers-shared-extra.yaml.j2 b/docker/complement/conf/workers-shared-extra.yaml.j2 index b55eed4040..f6e39a1426 100644 --- a/docker/complement/conf/workers-shared-extra.yaml.j2 +++ b/docker/complement/conf/workers-shared-extra.yaml.j2 @@ -130,7 +130,11 @@ experimental_features: # Invite filtering msc4155_enabled: true # Media Attachment - msc3911_enabled: true + msc3911: + enabled: true + block_unrestricted_media_upload: false + purge_pending_unattached_media: true + pending_media_cleanup_interval_ms: 86400000 # 1 day server_notices: system_mxid_localpart: _server diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index c5f54995f3..04b7e0a3a4 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -365,6 +365,22 @@ class MSC3866Config: require_approval_for_new_accounts: bool = False +@attr.s(auto_attribs=True, frozen=True, slots=True) +class MSC3911Config: + """Configuration for MSC3911 (Linking Media to Events)""" + + # MSC3911 is enabled + enabled: bool = False + + # Disable the current media create and upload endpoints + block_unrestricted_media_upload: bool = False + + # Delete pending media that is older than certain interval and not attached to any events. 
+ purge_pending_unattached_media: bool = False + # This configures how often the cleanup loop run in milliseconds + pending_media_cleanup_interval_ms: int = 24 * 60 * 60 * 1000 # 1 day + + class ExperimentalConfig(Config): """Config section for enabling experimental features""" @@ -590,14 +606,5 @@ def read_config( self.msc4306_enabled: bool = experimental.get("msc4306_enabled", False) # MSC3911: Linking Media to Events - self.msc3911_enabled: bool = experimental.get("msc3911_enabled", False) - - # MSC3911: Disable the current media create and upload endpoints - self.msc3911_unrestricted_media_upload_disabled: bool = experimental.get( - "msc3911_unrestricted_media_upload_disabled", False - ) - - # MSC3911: Delete pending media that is older than 24 hours but not attached to any events - self.msc3911_enabled_media_retention = experimental.get( - "msc3911_enabled_media_retention", False - ) + raw_msc3911_config = experimental.get("msc3911", {}) + self.msc3911 = MSC3911Config(**raw_msc3911_config) diff --git a/synapse/federation/transport/server/federation.py b/synapse/federation/transport/server/federation.py index 3bd6a9acab..8738fb65e6 100644 --- a/synapse/federation/transport/server/federation.py +++ b/synapse/federation/transport/server/federation.py @@ -79,7 +79,7 @@ def __init__( ): super().__init__(hs, authenticator, ratelimiter, server_name) self.handler = hs.get_federation_server() - self.enable_restricted_media = hs.config.experimental.msc3911_enabled + self.enable_restricted_media = hs.config.experimental.msc3911.enabled class FederationSendServlet(BaseFederationServerServlet): diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py index 5febb3a552..b32d3d10d0 100644 --- a/synapse/handlers/profile.py +++ b/synapse/handlers/profile.py @@ -80,9 +80,9 @@ def __init__(self, hs: "HomeServer"): self._third_party_rules = hs.get_module_api_callbacks().third_party_event_rules - self.enable_restricted_media = hs.config.experimental.msc3911_enabled - 
self.disable_unrestricted_media = ( - hs.config.experimental.msc3911_unrestricted_media_upload_disabled + self.enable_restricted_media = hs.config.experimental.msc3911.enabled + self.block_unrestricted_media = ( + hs.config.experimental.msc3911.block_unrestricted_media_upload ) async def get_profile(self, user_id: str, ignore_backoff: bool = True) -> JsonDict: @@ -334,7 +334,7 @@ async def validate_avatar_url(self, avatar_url: str, requester: Requester) -> No # event, it will either be copied/passed along/dropped depending on the # above circumstances if not media_info: - if self.disable_unrestricted_media: + if self.block_unrestricted_media: # The user should have done a COPY on this media previous to this # attempt to set raise SynapseError( @@ -345,7 +345,7 @@ async def validate_avatar_url(self, avatar_url: str, requester: Requester) -> No # For backwards compatible behavior, treat the media as unrestricted return - if self.disable_unrestricted_media and not media_info.restricted: + if self.block_unrestricted_media and not media_info.restricted: raise SynapseError( HTTPStatus.BAD_REQUEST, f"The media attachment request is invalid as the media '{mxc_uri.media_id}' is not restricted", diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 0306999284..5f595d16eb 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -1232,7 +1232,7 @@ async def create_room( room_avatar = initial_state.get((EventTypes.RoomAvatar, "")) # It is unfortunate that this conditional block has to be run twice: once here # to validate the data and again to actually attach the media reference. - if room_avatar is not None and self.config.experimental.msc3911_enabled: + if room_avatar is not None and self.config.experimental.msc3911.enabled: # this should be an mxc, but the spec does not specifically say it has to be extracted_media_id: Optional[str] = room_avatar.get("url") # It may be that "url" is set to either an empty string or None. 
Accept @@ -1550,7 +1550,7 @@ async def create_event( mxc_restrictions = None if ( - self.config.experimental.msc3911_enabled + self.config.experimental.msc3911.enabled and etype == EventTypes.RoomAvatar ): # this should be an mxc, but the spec does not specifically say it has to be diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 18c5b0437c..efc19b9946 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -109,9 +109,9 @@ def __init__(self, hs: "HomeServer"): self.account_data_handler = hs.get_account_data_handler() self.event_auth_handler = hs.get_event_auth_handler() self._worker_lock_handler = hs.get_worker_locks_handler() - self.enable_restricted_media = hs.config.experimental.msc3911_enabled + self.enable_restricted_media = hs.config.experimental.msc3911.enabled self.allow_legacy_media = ( - not hs.config.experimental.msc3911_unrestricted_media_upload_disabled + not hs.config.experimental.msc3911.block_unrestricted_media_upload ) self._membership_types_to_include_profile_data_in = { diff --git a/synapse/media/media_repository.py b/synapse/media/media_repository.py index 455c429e64..f965fc0124 100644 --- a/synapse/media/media_repository.py +++ b/synapse/media/media_repository.py @@ -101,7 +101,6 @@ # How often to run the background job to check for local and remote media # that should be purged according to the configured media retention settings. 
MEDIA_RETENTION_CHECK_PERIOD_MS = 60 * 60 * 1000 # 1 hour -PENDING_MEDIA_CLEANUP_INTERVAL_MS = 60 * 60 * 1000 # 1 hour class AbstractMediaRepository: @@ -113,7 +112,7 @@ def __init__(self, hs: "HomeServer"): self.server_name = hs.hostname self.store = hs.get_datastores().main self._is_mine_server_name = hs.is_mine_server_name - self.enable_media_restriction = self.hs.config.experimental.msc3911_enabled + self.enable_media_restriction = self.hs.config.experimental.msc3911.enabled @trace async def create_media_id_without_expiration( @@ -644,9 +643,14 @@ def __init__(self, hs: "HomeServer"): key=lambda limit: limit.time_period_ms, reverse=True ) - if self.hs.config.experimental.msc3911_enabled_media_retention: + if ( + self.hs.config.experimental.msc3911.purge_pending_unattached_media + and hs.config.media.media_instance_running_background_jobs + == hs.config.worker.worker_name + ): self.clock.looping_call( - self._pending_media_cleanup, PENDING_MEDIA_CLEANUP_INTERVAL_MS + self._pending_media_cleanup, + self.hs.config.experimental.msc3911.pending_media_cleanup_interval_ms, ) def _start_update_recently_accessed(self) -> Deferred: @@ -671,7 +675,9 @@ async def _update_recently_accessed(self) -> None: ) async def _pending_media_cleanup(self) -> None: - pending_media_ids = await self.store.get_pending_media_ids() + pending_media_ids = await self.store.get_pending_media_ids( + self.hs.config.experimental.msc3911.pending_media_cleanup_interval_ms + ) if pending_media_ids: await self.delete_local_media_ids(pending_media_ids) diff --git a/synapse/media/thumbnailer.py b/synapse/media/thumbnailer.py index 0b2b37d808..4c2bb3961a 100644 --- a/synapse/media/thumbnailer.py +++ b/synapse/media/thumbnailer.py @@ -271,7 +271,7 @@ def __init__( self.media_storage = media_storage self.store = hs.get_datastores().main self.dynamic_thumbnails = hs.config.media.dynamic_thumbnails - self.enable_media_restriction = self.hs.config.experimental.msc3911_enabled + self.enable_media_restriction = 
self.hs.config.experimental.msc3911.enabled async def respond_local_thumbnail( self, diff --git a/synapse/rest/client/media.py b/synapse/rest/client/media.py index 6a69993eb8..d2c491bcd9 100644 --- a/synapse/rest/client/media.py +++ b/synapse/rest/client/media.py @@ -410,7 +410,7 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: MediaConfigResource(hs).register(http_server) ThumbnailResource(hs, media_repo, media_repo.media_storage).register(http_server) DownloadResource(hs, media_repo).register(http_server) - if hs.config.experimental.msc3911_enabled: + if hs.config.experimental.msc3911.enabled: CreateResource(hs, media_repo, restricted=True).register(http_server) UploadRestrictedResource(hs, media_repo).register(http_server) CopyResource(hs, media_repo).register(http_server) diff --git a/synapse/rest/client/profile.py b/synapse/rest/client/profile.py index 0a33cd305a..be69e9f0a7 100644 --- a/synapse/rest/client/profile.py +++ b/synapse/rest/client/profile.py @@ -109,7 +109,7 @@ def __init__(self, hs: "HomeServer"): self.hs = hs self.profile_handler = hs.get_profile_handler() self.auth = hs.get_auth() - self.enable_restricted_media = hs.config.experimental.msc3911_enabled + self.enable_restricted_media = hs.config.experimental.msc3911.enabled async def on_GET( self, request: SynapseRequest, user_id: str, field_name: str @@ -178,7 +178,7 @@ async def on_PUT( content = parse_json_object_from_request(request) try: - new_value = content[field_name] + input_value = content[field_name] except KeyError: raise SynapseError( 400, f"Missing key '{field_name}'", errcode=Codes.MISSING_PARAM @@ -200,24 +200,24 @@ async def on_PUT( ) if field_name == ProfileFields.DISPLAYNAME: await self.profile_handler.set_displayname( - user, requester, new_value, is_admin, propagate=propagate + user, requester, input_value, is_admin, propagate=propagate ) elif field_name == ProfileFields.AVATAR_URL: - if self.enable_restricted_media and new_value: + if 
self.enable_restricted_media and input_value: current_avatar_url = ( await self.profile_handler.store.get_profile_avatar_url( requester.user ) ) # If new_value is the same as existing one, keep the function idempotent - if current_avatar_url and str(current_avatar_url) == new_value: + if current_avatar_url and str(current_avatar_url) == input_value: return 200, {} await self.profile_handler.set_avatar_url( - user, requester, new_value, is_admin, propagate=propagate + user, requester, input_value, is_admin, propagate=propagate ) else: await self.profile_handler.set_profile_field( - user, requester, field_name, new_value, is_admin + user, requester, field_name, input_value, is_admin ) return 200, {} diff --git a/synapse/rest/client/room.py b/synapse/rest/client/room.py index 0e09fd36d8..2b7100eff0 100644 --- a/synapse/rest/client/room.py +++ b/synapse/rest/client/room.py @@ -277,7 +277,7 @@ def __init__(self, hs: "HomeServer"): self._max_event_delay_ms = hs.config.server.max_event_delay_ms self._spam_checker_module_callbacks = hs.get_module_api_callbacks().spam_checker self.store = hs.get_datastores().main - self.enable_restricted_media = hs.config.experimental.msc3911_enabled + self.enable_restricted_media = hs.config.experimental.msc3911.enabled self.server_name = hs.config.server.server_name def register(self, http_server: HttpServer) -> None: @@ -486,7 +486,7 @@ def __init__(self, hs: "HomeServer"): self.auth = hs.get_auth() self._max_event_delay_ms = hs.config.server.max_event_delay_ms self.store = hs.get_datastores().main - self.enable_restricted_media = hs.config.experimental.msc3911_enabled + self.enable_restricted_media = hs.config.experimental.msc3911.enabled self.server_name = hs.config.server.server_name def register(self, http_server: HttpServer) -> None: diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py index 49944843ac..5bedc7d0f4 100644 --- a/synapse/rest/client/versions.py +++ b/synapse/rest/client/versions.py @@ -178,9 
+178,9 @@ async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: # MSC4155: Invite filtering "org.matrix.msc4155": self.config.experimental.msc4155_enabled, # MSC3911: Linking Media to Events - "org.matrix.msc3911": self.config.experimental.msc3911_enabled, + "org.matrix.msc3911": self.config.experimental.msc3911.enabled, # MSC3911: Unrestricted Media Upload - "org.matrix.msc3911.unrestricted_media_upload_disabled": self.config.experimental.msc3911_unrestricted_media_upload_disabled, + "org.matrix.msc3911.block_unrestricted_media_upload": self.config.experimental.msc3911.block_unrestricted_media_upload, }, }, ) diff --git a/synapse/rest/media/create_resource.py b/synapse/rest/media/create_resource.py index 4c1c486583..1638736be5 100644 --- a/synapse/rest/media/create_resource.py +++ b/synapse/rest/media/create_resource.py @@ -57,8 +57,8 @@ def __init__( cfg=hs.config.ratelimiting.rc_media_create, ) # MSC3911: If this is enabled, this endpoint will not allow media creation,which is unrestricted. 
- self.msc3911_unrestricted_media_upload_disabled = ( - hs.config.experimental.msc3911_unrestricted_media_upload_disabled + self.msc3911_block_unrestricted_media_upload = ( + hs.config.experimental.msc3911.block_unrestricted_media_upload ) self.restricted = restricted @@ -70,7 +70,7 @@ def __init__( self.PATTERNS = [re.compile("/_matrix/media/v1/create")] async def on_POST(self, request: SynapseRequest) -> None: - if not self.restricted and self.msc3911_unrestricted_media_upload_disabled: + if not self.restricted and self.msc3911_block_unrestricted_media_upload: raise SynapseError( 403, "Unrestricted media creation is disabled", diff --git a/synapse/rest/media/upload_resource.py b/synapse/rest/media/upload_resource.py index 97e478cb86..af8dd5ee12 100644 --- a/synapse/rest/media/upload_resource.py +++ b/synapse/rest/media/upload_resource.py @@ -53,9 +53,8 @@ def __init__(self, hs: "HomeServer", media_repo: "MediaRepository"): self._media_repository_callbacks = ( hs.get_module_api_callbacks().media_repository ) - # MSC3911: If this is enabled, this endpoint will not allow unrestricted media uploads. 
- self.msc3911_unrestricted_media_upload_disabled = ( - hs.config.experimental.msc3911_unrestricted_media_upload_disabled + self.msc3911_block_unrestricted_media_upload = ( + hs.config.experimental.msc3911.block_unrestricted_media_upload ) async def _get_file_metadata( @@ -117,7 +116,7 @@ class UploadServlet(BaseUploadServlet): PATTERNS = [re.compile("/_matrix/media/(r0|v3|v1)/upload$")] async def on_POST(self, request: SynapseRequest) -> None: - if self.msc3911_unrestricted_media_upload_disabled: + if self.msc3911_block_unrestricted_media_upload: raise SynapseError( 403, "Unrestricted media upload is disabled", diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py index 9b8cbe22cd..03783aee17 100644 --- a/synapse/storage/databases/main/media_repository.py +++ b/synapse/storage/databases/main/media_repository.py @@ -638,11 +638,12 @@ def get_pending_media_txn(txn: LoggingTransaction) -> Tuple[int, int]: "get_pending_media", get_pending_media_txn ) - async def get_pending_media_ids(self) -> list[str]: + async def get_pending_media_ids(self, interval: int) -> list[str]: """ - Get a list of ids of pending media that is older than 24 hours and unattached. + Get a list of ids of pending media that is older than the given interval and + unattached. 
""" - threshold_ts = self._clock.time_msec() - 24 * 60 * 60 * 1000 + threshold_ts = self._clock.time_msec() - interval def _get_pending_media_ids_txn(txn: LoggingTransaction) -> list[str]: sql = """ diff --git a/tests/federation/test_federation_media.py b/tests/federation/test_federation_media.py index b453b83115..6a417d5b93 100644 --- a/tests/federation/test_federation_media.py +++ b/tests/federation/test_federation_media.py @@ -232,7 +232,7 @@ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: def default_config(self) -> JsonDict: config = super().default_config() config.setdefault("experimental_features", {}) - config["experimental_features"].update({"msc3911_enabled": True}) + config["experimental_features"].update({"msc3911": {"enabled": True}}) return config def test_restricted_media_download_with_restrictions_field(self) -> None: @@ -505,7 +505,7 @@ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: def default_config(self) -> JsonDict: config = super().default_config() config.setdefault("experimental_features", {}) - config["experimental_features"].update({"msc3911_enabled": True}) + config["experimental_features"].update({"msc3911": {"enabled": True}}) return config def test_restricted_thumbnail_download_with_restrictions_field(self) -> None: diff --git a/tests/media/test_pending_media_deletion.py b/tests/media/test_pending_media_deletion.py index fb816227a1..bf76bcaaca 100644 --- a/tests/media/test_pending_media_deletion.py +++ b/tests/media/test_pending_media_deletion.py @@ -9,7 +9,7 @@ from synapse.rest import admin from synapse.rest.client import login, media, profile, room from synapse.server import HomeServer -from synapse.types import UserID +from synapse.types import JsonDict, UserID from synapse.util import Clock from synapse.util.stringutils import ( random_string, @@ -27,17 +27,17 @@ class PendingMediaDeletionTestCase(unittest.HomeserverTestCase): profile.register_servlets, ] - def 
make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: - config = self.default_config() - config.update( + def default_config(self) -> JsonDict: + config = super().default_config() + config.setdefault("experimental_features", {}).update( { - "experimental_features": { - "msc3911_enabled": True, - "msc3911_enabled_media_retention": True, + "msc3911": { + "enabled": True, + "purge_pending_unattached_media": True, }, - } + }, ) - return self.setup_test_homeserver(config=config) + return config def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.media_repository = hs.get_media_repository() @@ -55,7 +55,8 @@ def create_resource_dict(self) -> dict[str, Resource]: def test_pending_media_deletion_success(self) -> None: """ - Test that media that is older than 24 hours yet not attached to any event or profile is deleted. + Test that media that is older than given time interval and not attached to any + event or profile is deleted. """ assert isinstance(self.media_repository, MediaRepository) diff --git a/tests/replication/test_multi_media_repo.py b/tests/replication/test_multi_media_repo.py index 6d44394eff..b37452cdd9 100644 --- a/tests/replication/test_multi_media_repo.py +++ b/tests/replication/test_multi_media_repo.py @@ -502,7 +502,7 @@ def _count_remote_thumbnails(self) -> int: class CopyRestrictedResourceReplicationTestCase(BaseMultiWorkerStreamTestCase): """ - Tests copy API when `msc3911_enabled` is configured to be True. + Tests copy API when `msc3911.enabled` is configured to be True. 
""" servlets = [ @@ -515,7 +515,7 @@ def default_config(self) -> Dict[str, Any]: config = super().default_config() config.update( { - "experimental_features": {"msc3911_enabled": True}, + "experimental_features": {"msc3911": {"enabled": True}}, "media_repo_instances": ["media_worker_1"], } ) diff --git a/tests/rest/client/test_media.py b/tests/rest/client/test_media.py index f3f5339d4f..a1f3e6dcdc 100644 --- a/tests/rest/client/test_media.py +++ b/tests/rest/client/test_media.py @@ -64,7 +64,7 @@ from synapse.media.thumbnailer import ThumbnailProvider from synapse.media.url_previewer import IMAGE_CACHE_EXPIRY_MS from synapse.rest import admin -from synapse.rest.client import login, media, room +from synapse.rest.client import login, media, profile, room from synapse.server import HomeServer from synapse.storage.databases.main.media_repository import ( LocalMedia, @@ -2997,23 +2997,30 @@ def test_over_weekly_limit(self) -> None: class DisableUnrestrictedResourceTestCase(unittest.HomeserverTestCase): """ This test case simulates a homeserver with media create and upload endpoints are - limited when `msc3911_unrestricted_media_upload_disabled` is configured to be True. + limited when `msc3911.block_unrestricted_media_upload` is configured to be True. 
""" servlets = [ media.register_servlets, + admin.register_servlets, + login.register_servlets, + profile.register_servlets, + room.register_servlets, ] def default_config(self) -> JsonDict: config = super().default_config() config.setdefault("experimental_features", {}) config["experimental_features"].update( - {"msc3911_unrestricted_media_upload_disabled": True} + {"msc3911": {"enabled": True, "block_unrestricted_media_upload": True}} ) return config def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - self.media_repo = hs.get_media_repository_resource() + self.media_repo = hs.get_media_repository() + self.store = hs.get_datastores().main + self.user = self.register_user("david", "password") + self.tok = self.login("david", "password") def create_resource_dict(self) -> dict[str, Resource]: # Important detail: while we are specifically testing these endpoints are @@ -3026,7 +3033,7 @@ def create_resource_dict(self) -> dict[str, Resource]: def test_unrestricted_resource_creation_disabled(self) -> None: """ Tests that CreateResource raises an error when - `msc3911_unrestricted_media_upload_disabled` is True. + `msc3911.block_unrestricted_media_upload` is True. """ channel = self.make_request( "POST", @@ -3041,7 +3048,7 @@ def test_unrestricted_resource_creation_disabled(self) -> None: def test_unrestricted_resource_upload_disabled(self) -> None: """ Tests that UploadServlet raises an error when - `msc3911_unrestricted_media_upload_disabled` is True. + `msc3911.block_unrestricted_media_upload` is True. """ channel = self.make_request( "POST", @@ -3056,10 +3063,106 @@ def test_unrestricted_resource_upload_disabled(self) -> None: channel.json_body["error"], "Unrestricted media upload is disabled" ) + def test_attaching_unrestricted_media_to_profile_fails(self) -> None: + """ + Test that attaching unrestricted media to user profile fails when unrestricted + media is banned by configuration. + """ + # Create unrestricted media. 
+ content = io.BytesIO(SMALL_PNG) + assert isinstance(self.media_repo, MediaRepository) + content_uri = self.get_success( + self.media_repo.create_or_update_content( + "image/png", + "test_png_upload", + content, + 67, + UserID.from_string(self.user), + restricted=False, + ) + ) + + # Check media is unrestricted. + media_info = self.get_success(self.store.get_local_media(content_uri.media_id)) + assert media_info is not None + assert not media_info.restricted + + # Try to update user profile with unrestricted media. + channel = self.make_request( + "PUT", + f"/_matrix/client/v3/profile/{self.user}/avatar_url", + access_token=self.tok, + content={"avatar_url": str(content_uri)}, + ) + assert channel.code == HTTPStatus.BAD_REQUEST, channel.json_body + assert channel.json_body["errcode"] == Codes.INVALID_PARAM + assert "is not restricted" in channel.json_body["error"] + + def test_attaching_unreachable_remote_media_to_profile_fails(self) -> None: + """ + Test that media that can not be retrieved will fail to be attached to a user + profile when legacy unrestricted media is disabled. + """ + # Generate non-existing media. + nonexistent_mxc_uri = MXCUri.from_str("mxc://remote/fakeMediaId_2") + channel = self.make_request( + "PUT", + f"/_matrix/client/v3/profile/{self.user}/avatar_url", + access_token=self.tok, + content={"avatar_url": str(nonexistent_mxc_uri)}, + ) + + assert channel.code == HTTPStatus.NOT_FOUND, channel.json_body + assert channel.json_body["errcode"] == Codes.NOT_FOUND + + def test_create_room_with_missing_profile_avatar_media_succeeds(self) -> None: + """ + Test that a profile avatar that should automatically be included in a room + creator's join event does not break the room when the actual media of the avatar + is missing. + """ + # First inject a profile avatar url directly into the database. The handler + # functions for such can not be used as they do validation, and it would fail as + # the media does not actually exist. 
+ avatar_mxc_uri = MXCUri.from_str("mxc://fake-domain/whatever") + # Make sure to add the restrictions too + self.get_success_or_raise( + self.hs.get_datastores().main.set_media_restricted_to_user_profile( + avatar_mxc_uri.server_name, + avatar_mxc_uri.media_id, + self.user, + ) + ) + self.get_success_or_raise( + self.store.set_profile_avatar_url( + UserID.from_string(self.user), str(avatar_mxc_uri) + ) + ) + + # try and create room. This should succeed, but the avatar will have been + # stripped from the join event of the creator + room_id = self.helper.create_room_as( + self.user, + tok=self.tok, + ) + assert room_id is not None + # Make sure the avatar is not on the event + membership_as_set = self.get_success_or_raise( + self.store.get_membership_event_ids_for_user(self.user, room_id) + ) + join_event_id = membership_as_set.pop() + join_event = self.get_success_or_raise(self.store.get_event(join_event_id)) + assert join_event.content["membership"] == Membership.JOIN + assert "avatar_url" not in join_event.content + + # The display name would have been added, see if that is still there + assert "displayname" in join_event.content + assert join_event.content["displayname"] == "david" + class RestrictedResourceUploadTestCase(unittest.HomeserverTestCase): """ - Tests restricted media creation and upload endpoints when `msc3911_enabled` is + Tests restricted media creation and upload endpoints when `msc3911.enabled` is configured to be True. 
""" @@ -3072,7 +3175,7 @@ class RestrictedResourceUploadTestCase(unittest.HomeserverTestCase): def default_config(self) -> JsonDict: config = super().default_config() config.setdefault("experimental_features", {}) - config["experimental_features"].update({"msc3911_enabled": True}) + config["experimental_features"].update({"msc3911": {"enabled": True}}) return config def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: @@ -3207,7 +3310,7 @@ def test_async_upload_restricted_resource(self) -> None: class CopyRestrictedResourceTestCase(unittest.HomeserverTestCase): """ - Tests copy API when `msc3911_enabled` is configured to be True. + Tests copy API when `msc3911.enabled` is configured to be True. """ servlets = [ @@ -3220,7 +3323,7 @@ class CopyRestrictedResourceTestCase(unittest.HomeserverTestCase): def default_config(self) -> JsonDict: config = super().default_config() config.setdefault("experimental_features", {}) - config["experimental_features"].update({"msc3911_enabled": True}) + config["experimental_features"].update({"msc3911": {"enabled": True}}) return config def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: @@ -3620,7 +3723,7 @@ class RestrictedMediaVisibilityTestCase(unittest.HomeserverTestCase): def default_config(self) -> JsonDict: config = super().default_config() config.setdefault("experimental_features", {}) - config["experimental_features"].update({"msc3911_enabled": True}) + config["experimental_features"].update({"msc3911": {"enabled": True}}) return config def prepare( @@ -4524,7 +4627,7 @@ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: config["media_store_path"] = self.media_store_path config["enable_authenticated_media"] = True config["experimental_features"] = { - "msc3911_enabled": self.enable_restricted_media + "msc3911": {"enabled": self.enable_restricted_media} } provider_config = { diff --git a/tests/rest/client/test_media_download.py 
b/tests/rest/client/test_media_download.py index b755f6ded8..e71b4c8acc 100644 --- a/tests/rest/client/test_media_download.py +++ b/tests/rest/client/test_media_download.py @@ -58,7 +58,7 @@ class RestrictedResourceDownloadTestCase(unittest.HomeserverTestCase): def default_config(self) -> JsonDict: config = super().default_config() config.setdefault("experimental_features", {}) - config["experimental_features"].update({"msc3911_enabled": True}) + config["experimental_features"].update({"msc3911": {"enabled": True}}) return config def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: diff --git a/tests/rest/client/test_media_thumbnail.py b/tests/rest/client/test_media_thumbnail.py index dd65c1e868..1155372ce7 100644 --- a/tests/rest/client/test_media_thumbnail.py +++ b/tests/rest/client/test_media_thumbnail.py @@ -59,7 +59,7 @@ class RestrictedResourceThumbnailTestCase(unittest.HomeserverTestCase): def default_config(self) -> JsonDict: config = super().default_config() config.setdefault("experimental_features", {}) - config["experimental_features"].update({"msc3911_enabled": True}) + config["experimental_features"].update({"msc3911": {"enabled": True}}) # This is what the defaults are for both 'crop' and 'scale' as reference # We don't need to set these, but it's good to know # "thumbnail_sizes": [ diff --git a/tests/rest/client/test_profile.py b/tests/rest/client/test_profile.py index 52a48c22f5..8776dc919f 100644 --- a/tests/rest/client/test_profile.py +++ b/tests/rest/client/test_profile.py @@ -45,7 +45,6 @@ from tests.replication._base import BaseMultiWorkerStreamTestCase from tests.server import FakeChannel, make_request from tests.test_utils import SMALL_PNG -from tests.unittest import override_config from tests.utils import USE_POSTGRES_FOR_TESTS @@ -945,7 +944,7 @@ def prepare( def default_config(self) -> JsonDict: config = super().default_config() config.setdefault("experimental_features", {}) - 
config["experimental_features"].update({"msc3911_enabled": True}) + config["experimental_features"].update({"msc3911": {"enabled": True}}) return config def create_resource_dict(self) -> dict[str, Resource]: @@ -1048,26 +1047,6 @@ def test_attaching_unreachable_remote_media_to_profile_might_succeed(self) -> No assert channel.code == HTTPStatus.OK, channel.json_body - @override_config( - {"experimental_features": {"msc3911_unrestricted_media_upload_disabled": True}} - ) - def test_attaching_unreachable_remote_media_to_profile_fails(self) -> None: - """ - Test that media that can not be retrieved will fail to be attached to a user - profile when legacy unrestricted media is disabled. - """ - # Generate non-existing media. - nonexistent_mxc_uri = MXCUri.from_str("mxc://remote/fakeMediaId_2") - channel = self.make_request( - "PUT", - f"/_matrix/client/v3/profile/{self.user}/avatar_url", - access_token=self.tok, - content={"avatar_url": str(nonexistent_mxc_uri)}, - ) - - assert channel.code == HTTPStatus.NOT_FOUND, channel.json_body - assert channel.json_body["errcode"] == Codes.NOT_FOUND - def test_attaching_unrestricted_media_to_profile(self) -> None: """ Test that attaching unrestricted media to user profile also works @@ -1098,43 +1077,6 @@ def test_attaching_unrestricted_media_to_profile(self) -> None: ) assert channel.code == 200, channel.result - @override_config( - {"experimental_features": {"msc3911_unrestricted_media_upload_disabled": True}} - ) - def test_attaching_unrestricted_media_to_profile_fails(self) -> None: - """ - Test that attaching unrestricted media to user profile fails when unrestricted - media is banned by configuration. - """ - # Create unrestricted media. - content = io.BytesIO(SMALL_PNG) - content_uri = self.get_success( - self.media_repo.create_or_update_content( - "image/png", - "test_png_upload", - content, - 67, - UserID.from_string(self.user), - restricted=False, - ) - ) - - # Check media is unrestricted. 
- media_info = self.get_success(self.store.get_local_media(content_uri.media_id)) - assert media_info is not None - assert not media_info.restricted - - # Try to update user profile with unrestricted media. - channel = self.make_request( - "PUT", - f"/_matrix/client/v3/profile/{self.user}/avatar_url", - access_token=self.tok, - content={"avatar_url": str(content_uri)}, - ) - assert channel.code == HTTPStatus.BAD_REQUEST, channel.json_body - assert channel.json_body["errcode"] == Codes.INVALID_PARAM - assert "is not restricted" in channel.json_body["error"] - def test_attaching_already_attached_media_to_profile_fails(self) -> None: """ Test that attaching already attached media to user profile fails. @@ -1327,7 +1269,9 @@ def prepare( def default_config(self) -> JsonDict: config = super().default_config() config.setdefault("experimental_features", {}) - config["experimental_features"].update({"msc3911_enabled": True}) + config["experimental_features"].update({"msc3911": {"enabled": True}}) + # config["media_repo_instances"] = [MAIN_PROCESS_INSTANCE_NAME] + return config def create_media_and_set_restricted_flag(self, user_id: str) -> MXCUri: diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py index 3def0eec06..b03063e1ab 100644 --- a/tests/rest/client/test_rooms.py +++ b/tests/rest/client/test_rooms.py @@ -4534,7 +4534,7 @@ def prepare( def default_config(self) -> JsonDict: config = super().default_config() config.setdefault("experimental_features", {}) - config["experimental_features"].update({"msc3911_enabled": True}) + config["experimental_features"].update({"msc3911": {"enabled": True}}) return config def create_media_and_set_restricted_flag( @@ -4842,7 +4842,7 @@ def prepare( def default_config(self) -> JsonDict: config = super().default_config() config.setdefault("experimental_features", {}) - config["experimental_features"].update({"msc3911_enabled": True}) + config["experimental_features"].update({"msc3911": {"enabled": True}}) 
return config def create_media_and_set_restricted_flag( @@ -5138,7 +5138,7 @@ def prepare( def default_config(self) -> JsonDict: config = super().default_config() config.setdefault("experimental_features", {}) - config["experimental_features"].update({"msc3911_enabled": True}) + config["experimental_features"].update({"msc3911": {"enabled": True}}) return config def create_media_and_set_restricted_flag( @@ -5269,53 +5269,6 @@ def test_create_room_fails_with_malformed_room_avatar_url(self) -> None: room_id = self.create_room_with_avatar(avatar_mxc="junk", expected_code=400) assert room_id is None - @override_config( - {"experimental_features": {"msc3911_unrestricted_media_upload_disabled": True}} - ) - def test_create_room_with_missing_profile_avatar_media_succeeds(self) -> None: - """ - Test that a profile avatar that should automatically be included in a room - creator's join event does not break the room when the actual media of the avatar - is missing. - """ - # First inject a profile avatar url directly into the database. The handler - # functions for such can not be used as they do validation, and it would fail as - # the media does not actually exist. - avatar_mxc_uri = MXCUri.from_str("mxc://fake-domain/whatever") - # Make sure to add the restrictions too - self.get_success_or_raise( - self.hs.get_datastores().main.set_media_restricted_to_user_profile( - avatar_mxc_uri.server_name, - avatar_mxc_uri.media_id, - self.user, - ) - ) - self.get_success_or_raise( - self.store.set_profile_avatar_url( - UserID.from_string(self.user), str(avatar_mxc_uri) - ) - ) - - # try and create room. 
This should succeed, but the avatar will have been - # stripped from the join event of the creator - room_id = self.helper.create_room_as( - self.user, - tok=self.tok, - ) - assert room_id is not None - # Make sure the avatar is not on the event - membership_as_set = self.get_success_or_raise( - self.store.get_membership_event_ids_for_user(self.user, room_id) - ) - join_event_id = membership_as_set.pop() - join_event = self.get_success_or_raise(self.store.get_event(join_event_id)) - assert join_event.content["membership"] == Membership.JOIN - assert "avatar_url" not in join_event.content - - # The display name would have been added, see if that is still there - assert "displayname" in join_event.content - assert join_event.content["displayname"] == "david" - class RoomMemberEventMediaAttachmentTestCase(unittest.HomeserverTestCase): servlets = [ @@ -5351,7 +5304,7 @@ def prepare( def default_config(self) -> JsonDict: config = super().default_config() config.setdefault("experimental_features", {}) - config["experimental_features"].update({"msc3911_enabled": True}) + config["experimental_features"].update({"msc3911": {"enabled": True}}) return config def create_media_and_set_as_profile_avatar(self, user_id: str, tok: str) -> MXCUri: