[go: nahoru, domu]

Skip to content

Commit

Permalink
Refactor and add origin check to SIO
Browse files Browse the repository at this point in the history
  • Loading branch information
deiteris committed Mar 18, 2024
1 parent ce9b599 commit 8dd8d71
Show file tree
Hide file tree
Showing 6 changed files with 64 additions and 28 deletions.
4 changes: 2 additions & 2 deletions server/MMVCServerSIO.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,8 +140,8 @@ def localServer(logLevel: str = "critical", key_path: str | None = None, cert_pa
mp.freeze_support()

voiceChangerManager = VoiceChangerManager.get_instance(voiceChangerParams)
app_fastapi = MMVC_Rest.get_instance(voiceChangerManager, voiceChangerParams, PORT, args.allowed_origins)
app_socketio = MMVC_SocketIOApp.get_instance(app_fastapi, voiceChangerManager)
app_fastapi = MMVC_Rest.get_instance(voiceChangerManager, voiceChangerParams, args.allowed_origins, PORT)
app_socketio = MMVC_SocketIOApp.get_instance(app_fastapi, voiceChangerManager, args.allowed_origins, PORT)


if __name__ == "__mp_main__":
Expand Down
24 changes: 24 additions & 0 deletions server/mods/origins.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
from typing import Optional, Sequence
from urllib.parse import urlparse

ENFORCE_URL_ORIGIN_FORMAT = "Input origins must be well-formed URLs, i.e. https://google.com or https://www.google.com."
# Schemes for which loopback origins are always trusted.
SCHEMAS = ('http', 'https')
# Loopback host names that are always trusted.
LOCAL_ORIGINS = ('127.0.0.1', 'localhost')


def compute_local_origins(port: Optional[int] = None) -> list[str]:
    """Return the loopback origins the server should always trust.

    Builds every combination of ``SCHEMAS`` and ``LOCAL_ORIGINS`` in
    ``scheme://host[:port]`` form, ordered by scheme first, then host.

    Args:
        port: Port the server listens on. When ``None``, no port is appended.
            When it equals the scheme's default port (80 for http, 443 for
            https), it is omitted for that scheme. Browsers serialize an
            origin without its default port, so ``https://localhost:443``
            would never match an ``Origin`` header.

    Returns:
        List of origin strings.
    """
    default_ports = {'http': 80, 'https': 443}
    local_origins: list[str] = []
    for schema in SCHEMAS:
        for host in LOCAL_ORIGINS:
            origin = f'{schema}://{host}'
            if port is not None and port != default_ports.get(schema):
                origin += f':{port}'
            local_origins.append(origin)
    return local_origins


def normalize_origins(origins: Sequence[str]) -> set[str]:
    """Canonicalize user-supplied origins into ``scheme://host[:port]`` form.

    Each entry is parsed with ``urlparse``. Any path, query, fragment or
    credentials are discarded, and scheme and host come out lowercased.
    IPv6 literals are re-wrapped in brackets. A port equal to the scheme's
    default (80 for http, 443 for https) is dropped. The result therefore
    matches the way browsers serialize the ``Origin`` header.

    Args:
        origins: Origins such as ``"https://example.com"`` or
            ``"http://192.168.0.10:18888/"``.

    Returns:
        Set of canonical origin strings.

    Raises:
        ValueError: If an entry has no scheme or no host (e.g. ``"localhost:8080"``,
            which ``urlparse`` reads as scheme ``localhost``), or its port is
            not a valid integer.
    """
    default_ports = {'http': 80, 'https': 443}
    allowed_origins: set[str] = set()
    for origin in origins:
        url = urlparse(origin)
        hostname = url.hostname
        # An explicit raise instead of ``assert`` keeps validation active under
        # ``python -O``; also reject a missing host instead of emitting "scheme://None".
        if not url.scheme or not hostname:
            raise ValueError(f'{ENFORCE_URL_ORIGIN_FORMAT} Got: {origin!r}')
        if ':' in hostname:
            # urlparse strips the brackets from IPv6 literals; restore them.
            hostname = f'[{hostname}]'
        valid_origin = f'{url.scheme}://{hostname}'
        port = url.port  # raises ValueError for a non-numeric/out-of-range port
        if port and port != default_ports.get(url.scheme):
            valid_origin += f':{port}'
        allowed_origins.add(valid_origin)
    return allowed_origins
6 changes: 3 additions & 3 deletions server/restapi/MMVC_Rest.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from fastapi.routing import APIRoute
from fastapi.staticfiles import StaticFiles
from fastapi.exceptions import RequestValidationError
from typing import Callable
from typing import Callable, Optional, Sequence, Literal
from mods.log_control import VoiceChangaerLogger
from voice_changer.VoiceChangerManager import VoiceChangerManager

Expand Down Expand Up @@ -43,8 +43,8 @@ def get_instance(
cls,
voiceChangerManager: VoiceChangerManager,
voiceChangerParams: VoiceChangerParams,
port: int,
allowedOrigins: list[str],
allowedOrigins: Optional[Sequence[str]] = None,
port: Optional[int] = None,
):
if cls._instance is None:
logger.info("[Voice Changer] MMVC_Rest initializing...")
Expand Down
30 changes: 11 additions & 19 deletions server/restapi/mods/trustedorigin.py
Original file line number Diff line number Diff line change
@@ -1,35 +1,27 @@
import typing
from typing import Optional, Sequence, Literal

from urllib.parse import urlparse
from mods.origins import compute_local_origins, normalize_origins
from starlette.datastructures import Headers
from starlette.responses import PlainTextResponse
from starlette.types import ASGIApp, Receive, Scope, Send

ENFORCE_URL_ORIGIN_FORMAT = "Input origins must be well-formed URLs, i.e. https://google.com or https://www.google.com."


class TrustedOriginMiddleware:
def __init__(
self,
app: ASGIApp,
allowed_origins: typing.Optional[typing.Sequence[str]] = None,
port: typing.Optional[int] = None,
allowed_origins: Optional[Sequence[str]] = None,
port: Optional[int] = None,
) -> None:
schemas = ['http', 'https']
local_origins = [f'{schema}://{origin}' for schema in schemas for origin in ['127.0.0.1', 'localhost']]
if port is not None:
local_origins = [f'{origin}:{port}' for origin in local_origins]

self.allowed_origins: set[str] = set()
if allowed_origins is not None:
for origin in allowed_origins:
url = urlparse(origin)
assert url.scheme, ENFORCE_URL_ORIGIN_FORMAT
valid_origin = f'{url.scheme}://{url.hostname}'
if url.port:
valid_origin += f':{url.port}'
self.allowed_origins.add(valid_origin)

local_origins = compute_local_origins(port)
self.allowed_origins.update(local_origins)

if allowed_origins is not None:
normalized_origins = normalize_origins(allowed_origins)
self.allowed_origins.update(normalized_origins)

self.app = app

async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
Expand Down
20 changes: 18 additions & 2 deletions server/sio/MMVC_SocketIOApp.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import socketio
from mods.log_control import VoiceChangaerLogger
from mods.origins import compute_local_origins, normalize_origins

from typing import Sequence, Optional
from sio.MMVC_SocketIOServer import MMVC_SocketIOServer
from voice_changer.VoiceChangerManager import VoiceChangerManager
from const import getFrontendPath
Expand All @@ -12,10 +14,24 @@ class MMVC_SocketIOApp:
_instance: socketio.ASGIApp | None = None

@classmethod
def get_instance(cls, app_fastapi, voiceChangerManager: VoiceChangerManager):
def get_instance(
cls,
app_fastapi,
voiceChangerManager: VoiceChangerManager,
allowedOrigins: Optional[Sequence[str]] = None,
port: Optional[int] = None,
):
if cls._instance is None:
logger.info("[Voice Changer] MMVC_SocketIOApp initializing...")
sio = MMVC_SocketIOServer.get_instance(voiceChangerManager)

allowed_origins: set[str] = set()
local_origins = compute_local_origins(port)
allowed_origins.update(local_origins)
if allowedOrigins is not None:
normalized_origins = normalize_origins(allowedOrigins)
allowed_origins.update(normalized_origins)
sio = MMVC_SocketIOServer.get_instance(voiceChangerManager, list(allowed_origins))

app_socketio = socketio.ASGIApp(
sio,
other_asgi_app=app_fastapi,
Expand Down
8 changes: 6 additions & 2 deletions server/sio/MMVC_SocketIOServer.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,13 @@ class MMVC_SocketIOServer:
_instance: socketio.AsyncServer | None = None

@classmethod
def get_instance(cls, voiceChangerManager: VoiceChangerManager):
def get_instance(
cls,
voiceChangerManager: VoiceChangerManager,
allowedOrigins: list[str],
):
if cls._instance is None:
sio = socketio.AsyncServer(async_mode="asgi", cors_allowed_origins="*")
sio = socketio.AsyncServer(async_mode="asgi", cors_allowed_origins=allowedOrigins)
namespace = MMVC_Namespace.get_instance(voiceChangerManager)
sio.register_namespace(namespace)
cls._instance = sio
Expand Down

0 comments on commit 8dd8d71

Please sign in to comment.