diff options
Diffstat (limited to 'webserver.py')
| -rw-r--r-- | webserver.py | 214 | 
1 files changed, 151 insertions, 63 deletions
| diff --git a/webserver.py b/webserver.py index d1c0a1e..162badc 100644 --- a/webserver.py +++ b/webserver.py @@ -1,42 +1,43 @@  #!/usr/bin/env python -  """  Module serving video from zmq to a webserver.  """ -  __author__ = "Franoosh Corporation" -  import os -from collections import defaultdict +from collections import defaultdict, deque  import json  import logging +import zmq.asyncio  import asyncio -from threading import Thread +from collections import defaultdict +from contextlib import asynccontextmanager  import uvicorn +import zmq  from fastapi import (      FastAPI,      Request,      HTTPException,      WebSocket, +    WebSocketDisconnect,      templating,  )  from fastapi.responses import HTMLResponse  from fastapi.staticfiles import StaticFiles  from helpers import CustomLoggingFormatter -import zmq  CLIENTS_JSON_FILE = os.path.join(os.getcwd(), 'clients.json') +CLIENTS_DICT = defaultdict(list)  # CLIENTS_DICT[client_id] = [camera_id1, camera_id2, ...]   LOGFILE = 'webserver.log' -LOGLEVEL = logging.INFO +LOGLEVEL = logging.DEBUG  HOST = "127.0.0.1" -ZMQPORT = "9979" -WSPORT = "8008" -ZMQ_BACKEND_ADDR = f"tcp://{HOST}:{ZMQPORT}" -WS_BACKEND_ADDR = f"tcp://{HOST}:{WSPORT}" +ZMQ_PORT = "9979" +WEB_PORT = "8008" +CTRL_BACKEND_ADDR = f"tcp://{HOST}:{ZMQ_PORT}" +WEB_BACKEND_ADDR = f"tcp://{HOST}:{WEB_PORT}"  log_formatter = CustomLoggingFormatter()  handler = logging.FileHandler(LOGFILE, encoding='utf-8', mode='a') @@ -50,88 +51,175 @@ logging.basicConfig(      level=LOGLEVEL,  ) - - -app = FastAPI() +# Track websocket connections by (client_id, camera_id): +ws_connections = defaultdict(dict)  # ws_connections[client_id][camera_id] = websocket +ws_queues = defaultdict(dict) +ctrl_msg_que = asyncio.Queue() +# Create ZMQ context and socket: +zmq_context = zmq.asyncio.Context() +zmq_socket = zmq_context.socket(zmq.DEALER) +# Connect to ZMQ backend: +zmq_socket.connect(WEB_BACKEND_ADDR) + + +async def zmq_bridge(): +    """Bridge between ZMQ backend and websocket clients.""" +    while True: +        try: +            data = await zmq_socket.recv_multipart() +            topic, frame_data = None, None +            if len(data) == 2: +                topic, frame_data = data +            elif len(data) == 3: +                topic, _, _ = data +            else: +                logger.warning("Received invalid ZMQ message: %r", data) +                continue +            if topic: +                client_id, camera_id = topic.decode('utf-8').split(':', 1) +                if frame_data: +                    # Add client and camera to CLIENTS_DICT if new: +                    if not camera_id in CLIENTS_DICT[client_id]: +                        CLIENTS_DICT[client_id].append(camera_id) +                else: +                    # No frame data means a notification to remove camera: +                    try: +                        CLIENTS_DICT[client_id].remove(camera_id) +                    except ValueError: +                        pass + +            queue = ws_queues.get(client_id, {}).get(camera_id) +            if queue and frame_data: +                if queue.full(): +                    _ = queue.get_nowait()  # Discard oldest frame +                await queue.put(frame_data) +            if not ctrl_msg_que.empty(): +                client_id, camera_id, command, args_list = await ctrl_msg_que.get() +                zmq_socket.send_multipart([ +                    client_id.encode('utf-8'), +                    camera_id.encode('utf-8'), +                    command.encode('utf-8'), +                ] +  args_list) +                logger.info("Sent control command '%s' with args: %r for '/ws/%s/%s' to backend.", command, args_list, client_id, camera_id) +        except Exception as e: +            logger.error("Error in ZMQ bridge: %r", e) +    # TODO: Check if this loop can be optimized to avoid busy waiting. +    # Alternative implementation using zmq Poller: +    # poll = zmq.asyncio.Poller() +    # poll.register(zmq_socket, zmq.POLLIN) +    # while True: +    #     try: +    #         sockets = dict(await poll.poll()) +    #         if zmq_socket in sockets: +    #             topic, frame_data = await zmq_socket.recv_multipart() +    #             client_id, camera_id = topic.decode('utf-8').split(':', 1) +    #             set_clients(client_id, camera_id) +    #             queue = ws_queues.get(client_id, {}).get(camera_id) +    #             if queue: +    #                 if queue.full(): +    #                     _ = queue.get_nowait()  # Discard oldest frame +    #                 await queue.put(frame_data) +    #         if not ctrl_msg_que.empty(): +    #             client_id, camera_id, command, args_list = await ctrl_msg_que.get() +    #             zmq_socket.send_multipart([ +    #                 client_id.encode('utf-8'), +    #                 camera_id.encode('utf-8'), +    #                 command.encode('utf-8'), +    #             ] +  args_list) +    #             logger.info("Sent control command '%s' with args: %r for '/ws/%s/%s' to backend.", command, args_list, client_id, camera_id) +    #     except Exception as e: +    #         logger.error("Error in ZMQ bridge: %r", e) + +@asynccontextmanager +async def lifespan(app: FastAPI): +    """Create lifespan context for FastAPI app.""" +    asyncio.create_task(zmq_bridge()) +    yield + +app = FastAPI(lifespan=lifespan)  app.mount("/static", StaticFiles(directory="static"), name="static")  templates = templating.Jinja2Templates(directory='templates') -# Track websocket connections by (client_id, camera_id) -ws_connections = defaultdict(dict)  # ws_connections[client_id][camera_id] = websocket - -# Set up a single ZMQ SUB socket for all websocket connections -zmq_context = zmq.Context() -zmq_socket = zmq_context.socket(zmq.SUB) -zmq_socket.bind(WS_BACKEND_ADDR) -zmq_socket.setsockopt(zmq.SUBSCRIBE, b"")  # Subscribe to all topics -poller = zmq.Poller() -poller.register(zmq_socket, zmq.POLLIN) - -def load_clients(): -    try: -        with open(CLIENTS_JSON_FILE) as f: -            clients_dict = json.load(f) -    except FileNotFoundError: -        clients_dict = {} -    return clients_dict  @app.get("/")  async def main_route(request: Request): -    logger.error("DEBUG: main route visited") -    clients = load_clients() +    """Serve main page.""" +    logger.debug("Main route visited")      return templates.TemplateResponse(          "main.html",          {              "request": request, -            "clients": clients, +            "clients": CLIENTS_DICT,           }      )  @app.get("/clients/{client_id}", response_class=HTMLResponse)  async def client_route(request: Request, client_id: str):      """Serve client page.""" -    clients_dict = load_clients() -    logger.debug("Checking client_id: '%s' in clients_dict: %r.", client_id, clients_dict) -    if not client_id in clients_dict: -        return HTTPException(status_code=404, detail="No such client ID.") +    logger.debug("Checking client_id: '%s' in clients_dict: %r.", client_id, CLIENTS_DICT) +    if not client_id in CLIENTS_DICT: +        raise HTTPException(status_code=404, detail="No such client ID.")      return templates.TemplateResponse(          "client.html",          {              "request": request,              "client_id": client_id, -            "camera_ids": clients_dict[client_id], +            "camera_ids": CLIENTS_DICT[client_id],          },      ) -  @app.websocket("/ws/{client_id}/{camera_id}")  async def camera_route(websocket: WebSocket, client_id: str, camera_id: str):      """Serve a particular camera page."""      logger.info("Accepting websocket connection for '/ws/%s/%s'.", client_id, camera_id)      await websocket.accept() -    if client_id not in ws_connections: -        ws_connections[client_id] = {} -    ws_connections[client_id][camera_id] = websocket -    try: +    ws_connections[client_id][camera_id] = {'ws': websocket} +    ws_queues[client_id][camera_id] = asyncio.Queue(maxsize=10) +    queue = ws_queues[client_id][camera_id] +     +    async def send_frames(): +        while True: +            frame_data = await queue.get() +            await websocket.send_bytes(frame_data) +     +    async def receive_control():          while True: -            # Wait for a frame for this client/camera -            sockets = dict(poller.poll(1000)) -            if zmq_socket in sockets: -                msg = zmq_socket.recv_multipart() -                if len(msg) == 3: -                    recv_client_id, recv_camera_id, content = msg -                    recv_client_id = recv_client_id.decode("utf-8") -                    recv_camera_id = recv_camera_id.decode("utf-8") -                    # Only send to the websocket for this client/camera -                    if recv_client_id == client_id and recv_camera_id == camera_id: -                        await websocket.send_bytes(content) -    except Exception as exc: -        logger.warning("Connection closed: %r", exc) +            try: +                data = await websocket.receive_text() +                logger.info("Received control message from '/ws/%s/%s': %s", client_id, camera_id, data) +                # Handle control messages from the client: +                frontend_message = json.loads(data) +                for command, args in frontend_message.items(): +                    args_list = [str(arg).encode('utf-8') for arg in args] +                    ctrl_msg_que.put_nowait((client_id, camera_id, command, args_list)) +                    # zmq_control_socket.send_multipart([ +                    #     client_id.encode('utf-8'), +                    #     camera_id.encode('utf-8'), +                    #     command.encode('utf-8'), +                    #     b'', +                    # ] + args_list) +                    # logger.info("Sent control command '%s' with args: %r for '/ws/%s/%s' to backend.", command, args_list, client_id, camera_id) +                    logger.info("Put control command '%s' with args: %r for '/ws/%s/%s' on queue to backend.", command, args_list, client_id, camera_id) +            except json.JSONDecodeError: +                logger.warning("Received invalid JSON from '/ws/%s/%s': %s", client_id, camera_id, data) +            except WebSocketDisconnect: +                logger.info("WebSocket disconnected for '/ws/%s/%s'.", client_id, camera_id) +                break +            except Exception as exc: +                logger.warning("Error receiving control message: %r", exc) +    send_task = asyncio.create_task(send_frames()) +    receive_task = asyncio.create_task(receive_control()) +    try: +        # await asyncio.gather(send_task, receive_task) +        await asyncio.wait( +            [send_task, receive_task], +            return_when=asyncio.FIRST_COMPLETED, +        )      finally: -        if client_id in ws_connections and camera_id in ws_connections[client_id]: -            del ws_connections[client_id][camera_id] -        await websocket.close() - +        send_task.cancel() +        receive_task.cancel() +        ws_connections[client_id].pop(camera_id, None) +        ws_queues[client_id].pop(camera_id, None)  if __name__ == "__main__":      uvicorn.run( @@ -139,4 +227,4 @@ if __name__ == "__main__":          port=8007,          host='127.0.0.1',          log_level='info', -    )
\ No newline at end of file +    ) | 
