Skip to content
Open
5 changes: 2 additions & 3 deletions .env.example
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,10 @@ UVICORN_PORT = 8000
# USER_SUBSCRIPTION_CLIENTS_LIMIT = 10

# CUSTOM_TEMPLATES_DIRECTORY="/var/lib/pasarguard/templates/"
# CLASH_SUBSCRIPTION_TEMPLATE="clash/my-custom-template.yml"
# SUBSCRIPTION_PAGE_TEMPLATE="subscription/index.html"
# HOME_PAGE_TEMPLATE="home/index.html"
# XRAY_SUBSCRIPTION_TEMPLATE="xray/default.json"
# SINGBOX_SUBSCRIPTION_TEMPLATE="singbox/default.json"
# Core subscription templates are stored in DB table `core_templates`
# and managed via `/api/core_template`.

## External config to import into v2ray format subscription
# EXTERNAL_CONFIG = "config://..."
Expand Down
8 changes: 6 additions & 2 deletions app/app_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from app.nats.message import MessageTopic
from app.nats.router import router
from app.settings import handle_settings_message
from app.subscription.client_templates import handle_client_template_message
from app.utils.logger import get_logger
from app.version import __version__
from config import DOCS, ROLE, SUBSCRIPTION_PATH
Expand All @@ -24,12 +25,14 @@ def _use_route_names_as_operation_ids(app: FastAPI) -> None:
route.operation_id = route.name


def _register_nats_handlers(enable_router: bool, enable_settings: bool):
def _register_nats_handlers(enable_router: bool, enable_settings: bool, enable_client_templates: bool):
if enable_router:
on_startup(router.start)
on_shutdown(router.stop)
if enable_settings:
router.register_handler(MessageTopic.SETTING, handle_settings_message)
if enable_client_templates:
router.register_handler(MessageTopic.CLIENT_TEMPLATE, handle_client_template_message)


def _register_scheduler_hooks():
Expand Down Expand Up @@ -105,7 +108,8 @@ def _validate_paths():

enable_router = ROLE.runs_panel or ROLE.runs_node or ROLE.runs_scheduler
enable_settings = ROLE.runs_panel or ROLE.runs_scheduler
_register_nats_handlers(enable_router, enable_settings)
enable_client_templates = ROLE.runs_panel or ROLE.runs_scheduler
_register_nats_handlers(enable_router, enable_settings, enable_client_templates)
_register_scheduler_hooks()
_register_jobs()

Expand Down
2 changes: 2 additions & 0 deletions app/db/crud/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from .admin import get_admin
from .core import get_core_config_by_id
from .client_template import get_client_template_by_id
from .group import get_group_by_id
from .host import get_host_by_id
from .node import get_node_by_id
Expand All @@ -10,6 +11,7 @@
__all__ = [
"get_admin",
"get_core_config_by_id",
"get_client_template_by_id",
"get_group_by_id",
"get_host_by_id",
"get_node_by_id",
Expand Down
251 changes: 251 additions & 0 deletions app/db/crud/client_template.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,251 @@
from collections import defaultdict
from collections.abc import Mapping
from enum import Enum

from sqlalchemy import func, select, update
from sqlalchemy.exc import IntegrityError, SQLAlchemyError
from sqlalchemy.ext.asyncio import AsyncSession

from app.db.models import ClientTemplate
from app.models.client_template import ClientTemplateCreate, ClientTemplateModify, ClientTemplateType
from app.subscription.default_templates import DEFAULT_TEMPLATE_CONTENTS_BY_LEGACY_KEY

TEMPLATE_TYPE_TO_LEGACY_KEY: dict[ClientTemplateType, str] = {
ClientTemplateType.clash_subscription: "CLASH_SUBSCRIPTION_TEMPLATE",
ClientTemplateType.xray_subscription: "XRAY_SUBSCRIPTION_TEMPLATE",
ClientTemplateType.singbox_subscription: "SINGBOX_SUBSCRIPTION_TEMPLATE",
ClientTemplateType.user_agent: "USER_AGENT_TEMPLATE",
ClientTemplateType.grpc_user_agent: "GRPC_USER_AGENT_TEMPLATE",
}

ClientTemplateSortingOptionsSimple = Enum(
"ClientTemplateSortingOptionsSimple",
{
"id": ClientTemplate.id.asc(),
"-id": ClientTemplate.id.desc(),
"name": ClientTemplate.name.asc(),
"-name": ClientTemplate.name.desc(),
"type": ClientTemplate.template_type.asc(),
"-type": ClientTemplate.template_type.desc(),
},
)


def get_default_client_template_contents() -> dict[str, str]:
return DEFAULT_TEMPLATE_CONTENTS_BY_LEGACY_KEY.copy()


def merge_client_template_values(values: Mapping[str, str] | None = None) -> dict[str, str]:
merged = get_default_client_template_contents()
if not values:
return merged

for key, value in values.items():
if key in merged and value:
merged[key] = value

return merged


async def get_client_template_values(db: AsyncSession) -> dict[str, str]:
defaults = get_default_client_template_contents()
try:
rows = (
await db.execute(
select(
ClientTemplate.id,
ClientTemplate.template_type,
ClientTemplate.content,
ClientTemplate.is_default,
).order_by(ClientTemplate.template_type.asc(), ClientTemplate.id.asc())
)
).all()
except SQLAlchemyError:
return defaults

by_type: dict[str, list[tuple[int, str, bool]]] = defaultdict(list)
for row in rows:
by_type[row.template_type].append((row.id, row.content, row.is_default))

values: dict[str, str] = {}
for template_type, legacy_key in TEMPLATE_TYPE_TO_LEGACY_KEY.items():
type_rows = by_type.get(template_type.value, [])
if not type_rows:
continue

selected_content = ""
for _, content, is_default in type_rows:
if is_default:
selected_content = content
break

if not selected_content:
selected_content = type_rows[0][1]

if selected_content:
values[legacy_key] = selected_content

return merge_client_template_values(values)


async def get_client_template_by_id(db: AsyncSession, template_id: int) -> ClientTemplate | None:
return (await db.execute(select(ClientTemplate).where(ClientTemplate.id == template_id))).unique().scalar_one_or_none()


async def get_client_templates(
db: AsyncSession,
template_type: ClientTemplateType | None = None,
offset: int | None = None,
limit: int | None = None,
) -> tuple[list[ClientTemplate], int]:
query = select(ClientTemplate)
if template_type is not None:
query = query.where(ClientTemplate.template_type == template_type.value)

total = (await db.execute(select(func.count()).select_from(query.subquery()))).scalar() or 0

query = query.order_by(ClientTemplate.template_type.asc(), ClientTemplate.id.asc())
if offset:
query = query.offset(offset)
if limit:
query = query.limit(limit)

rows = (await db.execute(query)).scalars().all()
return rows, total


async def get_client_templates_simple(
db: AsyncSession,
offset: int | None = None,
limit: int | None = None,
search: str | None = None,
template_type: ClientTemplateType | None = None,
sort: list[ClientTemplateSortingOptionsSimple] | None = None,
skip_pagination: bool = False,
) -> tuple[list[tuple[int, str, str, bool]], int]:
stmt = select(ClientTemplate.id, ClientTemplate.name, ClientTemplate.template_type, ClientTemplate.is_default)

if search:
stmt = stmt.where(ClientTemplate.name.ilike(f"%{search.strip()}%"))

if template_type is not None:
stmt = stmt.where(ClientTemplate.template_type == template_type.value)

if sort:
sort_list = []
for s in sort:
if isinstance(s.value, tuple):
sort_list.extend(s.value)
else:
sort_list.append(s.value)
stmt = stmt.order_by(*sort_list)
else:
stmt = stmt.order_by(ClientTemplate.template_type.asc(), ClientTemplate.id.asc())

total = (await db.execute(select(func.count()).select_from(stmt.subquery()))).scalar() or 0

if not skip_pagination:
if offset:
stmt = stmt.offset(offset)
if limit:
stmt = stmt.limit(limit)
else:
stmt = stmt.limit(10000)

rows = (await db.execute(stmt)).all()
return rows, total


async def count_client_templates_by_type(db: AsyncSession, template_type: ClientTemplateType) -> int:
count_stmt = select(func.count()).select_from(ClientTemplate).where(ClientTemplate.template_type == template_type.value)
return (await db.execute(count_stmt)).scalar() or 0


async def get_first_template_by_type(
db: AsyncSession,
template_type: ClientTemplateType,
exclude_id: int | None = None,
) -> ClientTemplate | None:
stmt = (
select(ClientTemplate)
.where(ClientTemplate.template_type == template_type.value)
.order_by(ClientTemplate.id.asc())
)
if exclude_id is not None:
stmt = stmt.where(ClientTemplate.id != exclude_id)
return (await db.execute(stmt)).scalars().first()


async def set_default_template(db: AsyncSession, db_template: ClientTemplate) -> ClientTemplate:
await db.execute(
update(ClientTemplate)
.where(ClientTemplate.template_type == db_template.template_type)
.values(is_default=False)
)
db_template.is_default = True
await db.commit()
await db.refresh(db_template)
return db_template


async def create_client_template(db: AsyncSession, client_template: ClientTemplateCreate) -> ClientTemplate:
type_count = await count_client_templates_by_type(db, client_template.template_type)
is_first_for_type = type_count == 0
should_be_default = client_template.is_default or is_first_for_type

if should_be_default:
await db.execute(
update(ClientTemplate)
.where(ClientTemplate.template_type == client_template.template_type.value)
.values(is_default=False)
)

db_template = ClientTemplate(
name=client_template.name,
template_type=client_template.template_type.value,
content=client_template.content,
is_default=should_be_default,
is_system=is_first_for_type,
)
db.add(db_template)
try:
await db.commit()
except IntegrityError:
await db.rollback()
raise
await db.refresh(db_template)
return db_template


async def modify_client_template(
db: AsyncSession,
db_template: ClientTemplate,
modified_template: ClientTemplateModify,
) -> ClientTemplate:
template_data = modified_template.model_dump(exclude_none=True)

if modified_template.is_default is True:
await db.execute(
update(ClientTemplate)
.where(ClientTemplate.template_type == db_template.template_type)
.values(is_default=False)
)
db_template.is_default = True

if "name" in template_data:
db_template.name = template_data["name"]
if "content" in template_data:
db_template.content = template_data["content"]

try:
await db.commit()
except IntegrityError:
await db.rollback()
raise
await db.refresh(db_template)
return db_template


async def remove_client_template(db: AsyncSession, db_template: ClientTemplate) -> None:
await db.delete(db_template)
await db.commit()
Loading