diff --git a/server/Cargo.lock b/server/Cargo.lock index 26bcf06..38fd3a3 100644 --- a/server/Cargo.lock +++ b/server/Cargo.lock @@ -34,9 +34,9 @@ checksum = "b3d1d046238990b9cf5bcde22a3fb3584ee5cf65fb2765f454ed428c7a0063da" [[package]] name = "async-trait" -version = "0.1.82" +version = "0.1.83" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a27b8a3a6e1a44fa4c8baf1f653e4172e81486d4941f2237e20dc2d0cf4ddff1" +checksum = "721cae7de5c34fbb2acd27e21e6d2cf7b886dce0c27388d46c4e6c47ea4318dd" dependencies = [ "proc-macro2", "quote", @@ -1406,6 +1406,7 @@ name = "vncstream_server" version = "0.1.0" dependencies = [ "anyhow", + "async-trait", "axum", "cudarc", "ffmpeg-next", diff --git a/server/Cargo.toml b/server/Cargo.toml index 739aced..11a6121 100644 --- a/server/Cargo.toml +++ b/server/Cargo.toml @@ -31,6 +31,7 @@ cudarc = { version = "0.12.1", features = [ "cuda-11050" ] } tracing = "0.1.40" tracing-subscriber = "0.3.18" xkeysym = "0.2.1" +async-trait = "0.1.83" [patch.crates-io] diff --git a/server/src/main.rs b/server/src/main.rs index 0ccabcc..cd5f216 100644 --- a/server/src/main.rs +++ b/server/src/main.rs @@ -3,7 +3,12 @@ mod surface; mod types; mod video; +mod transport; + +use async_trait::async_trait; + use retro_thread::{spawn_retro_thread, RetroEvent}; +use transport::{Transport, TransportReciever}; use video::encoder_thread::EncodeThreadInput; use video::{encoder_thread, ffmpeg}; @@ -43,22 +48,135 @@ struct AppState { encoder_tx: Arc>>, inputs: Arc>>, - websocket_broadcast_tx: broadcast::Sender, - websocket_count: TokioMutex, + transport: Arc, + connection_count: TokioMutex, } impl AppState { - fn new(encoder_tx: mpsc::Sender) -> Self { - let (broadcast_tx, _) = broadcast::channel(10); + fn new( + encoder_tx: mpsc::Sender, + transport: Arc, + ) -> Self { Self { encoder_tx: Arc::new(TokioMutex::new(encoder_tx)), inputs: Arc::new(TokioMutex::new(Vec::new())), - websocket_broadcast_tx: broadcast_tx, - websocket_count: TokioMutex::const_new(0usize), + transport: transport, + connection_count: TokioMutex::const_new(0usize), } } } +#[async_trait] +impl TransportReciever for AppState { + async fn on_connect(&self, username: &String) -> anyhow::Result<()> { + println!("{username} joined!"); + + { + let mut lk = self.connection_count.lock().await; + *lk += 1; + } + + { + let locked = self.encoder_tx.lock().await; + + // Force a ws connection to mean a keyframe + let _ = locked.send(EncodeThreadInput::ForceKeyframe).await; + let _ = locked.send(EncodeThreadInput::SendFrame).await; + } + + Ok(()) + } + + async fn on_message(&self, username: &String, message: &String) -> anyhow::Result<()> { + if let Ok(json) = serde_json::from_str::(&message) { + if !json["type"].is_string() { + return Ok(()); + } + + match json["type"].as_str().unwrap() { + "chat" => { + if !json["msg"].is_string() { + return Ok(()); + } + + let send = serde_json::json!({ + "type": "chat", + "username": username, + "msg": json["msg"].as_str().unwrap() + }); + + + self.transport + .broadcast_message(transport::TransportMessage::Text( + serde_json::to_string(&send).expect("oh well"), + ))?; + } + + "key" => { + if !json["keysym"].is_number() { + return Ok(()); + } + + if !json["pressed"].is_number() { + return Ok(()); + } + + let keysym = json["keysym"].as_u64().unwrap() as u32; + let pressed = json["pressed"].as_u64().unwrap() == 1; + + // FIXME: This would be MUCH better off being a set, so we don't + // hack-code set semantics here. Oh well. + { + let mut lock = self.inputs.lock().await; + if pressed { + if let None = lock.iter().position(|e| *e == keysym) { + lock.push(keysym); + } + } else { + if let Some(at) = lock.iter().position(|e| *e == keysym) { + lock.remove(at); + } + } + } + } + + "mouse" => { + if json["x"].as_u64().is_none() { + return Ok(()); + } + + if json["y"].as_u64().is_none() { + return Ok(()); + } + + if json["mask"].as_u64().is_none() { + return Ok(()); + } + + //let x = json["x"].as_u64().unwrap() as u32; + //let y = json["y"].as_u64().unwrap() as u32; + //let mask = json["mask"].as_u64().unwrap() as u8; + } + _ => {} + } + } else { + return Ok(()); + } + + Ok(()) + } + + async fn on_leave(&self, username: &String) -> anyhow::Result<()> { + { + let mut lk = self.connection_count.lock().await; + *lk -= 1; + } + + println!("{username} left."); + Ok(()) + } +} + #[tokio::main(flavor = "multi_thread", worker_threads = 2)] async fn main() -> anyhow::Result<()> { // Setup a tracing subscriber @@ -74,12 +192,16 @@ async fn main() -> anyhow::Result<()> { let frame: Arc>> = Arc::new(Mutex::new(None)); let (mut encoder_rx, encoder_tx) = encoder_thread::encoder_thread_spawn(&frame); - let state = Arc::new(AppState::new(encoder_tx)); + let transport = Arc::new(crate::transport::websocket::WebsocketTransport::new()); + + let state = Arc::new(AppState::new(encoder_tx, transport.clone())); let (mut event_rx, event_in_tx) = spawn_retro_thread(surface.clone()); let state_clone = state.clone(); + let transport_clone = transport.clone(); + // retro event handler. drives the encoder thread too let _ = std::thread::Builder::new() .name("retro_event_rx".into()) @@ -168,9 +290,18 @@ async fn main() -> anyhow::Result<()> { match encoder_rx.try_recv() { Ok(msg) => match msg { encoder_thread::EncodeThreadOutput::Frame { packet } => { - let _ = state_clone - .websocket_broadcast_tx - .send(WsMessage::VideoPacket { packet }); + // let _ = state_clone + // .websocket_broadcast_tx + // .send(WsMessage::VideoPacket { packet }); + + // :( + let packet_data = { + let slice = packet.data().expect( + "should NOT be empty, this invariant is checked beforehand", + ); + slice.to_vec() + }; + let _ = transport_clone.broadcast_message(transport::TransportMessage::Binary(packet_data)); } }, Err(TryRecvError::Empty) => {} @@ -182,223 +313,6 @@ async fn main() -> anyhow::Result<()> { }) .expect("failed to spawn retro RX thread, it's probably over"); - // Axum websocket server - let app: Router<()> = Router::new() - .route( - "/", - get( - |ws: WebSocketUpgrade, - info: ConnectInfo, - state: State>| async move { - ws_handler(ws, info, state).await - }, - ), - ) - .with_state(state.clone()); - - let tcp_listener = tokio::net::TcpListener::bind("0.0.0.0:4940") - .await - .expect("failed to listen"); - - let axum_future = axum::serve( - tcp_listener, - app.into_make_service_with_connect_info::(), - ); - - // If the VNC client disconnects we should exit. - tokio::select! { - _ = axum_future => { - println!("axum died"); - } - - } - + transport.listen(state).await?; Ok(()) } - -async fn ws_handler( - ws: WebSocketUpgrade, - ConnectInfo(addr): ConnectInfo, - - State(state): State>, -) -> impl IntoResponse { - // finalize the upgrade process by returning upgrade callback. - // we can customize the callback by sending additional info such as address. - ws.on_upgrade(move |socket| handle_socket(socket, addr, state)) -} - -/// Actual websocket statemachine (one will be spawned per connection) -async fn handle_socket(socket: WebSocket, who: SocketAddr, state: Arc) { - let (mut sender, mut receiver) = socket.split(); - - // increment connection count - let inc_clone = Arc::clone(&state); - { - let mut lk = inc_clone.websocket_count.lock().await; - *lk += 1; - } - - { - let locked = state.encoder_tx.lock().await; - - // Force a ws connection to mean a keyframe - let _ = locked.send(EncodeThreadInput::ForceKeyframe).await; - let _ = locked.send(EncodeThreadInput::SendFrame).await; - } - - // random username - let username: Arc = - Arc::new(rand::distributions::Alphanumeric.sample_string(&mut rand::thread_rng(), 16)); - - println!("{username} ({who}) connected."); - - let send_clone = Arc::clone(&state); - let mut send_task = tokio::spawn(async move { - let mut sub = send_clone.websocket_broadcast_tx.subscribe(); - - while let Ok(msg) = sub.recv().await { - match msg { - WsMessage::VideoPacket { mut packet } => { - // :(. At least this copy doesn't occur on the driver threads anymore.. - let data = packet.data_mut().expect("shouldn't be taken"); - let msg = ws::Message::Binary(data.to_vec()); - if sender.send(msg).await.is_err() { - break; - } - } - - WsMessage::Json(s) => { - let msg = ws::Message::Text(s); - if sender.send(msg).await.is_err() { - break; - } - } - } - } - }); - - let username_clone = Arc::clone(&username); - - let recv_clone = Arc::clone(&state); - - let mut recv_task = tokio::spawn(async move { - while let Some(Ok(msg)) = receiver.next().await { - match msg { - Message::Text(msg) => { - // println!("{}", msg); - if let Ok(json) = serde_json::from_str::(&msg) { - if !json["type"].is_string() { - break; - } - - match json["type"].as_str().unwrap() { - "chat" => { - if !json["msg"].is_string() { - break; - } - - let send = serde_json::json!({ - "type": "chat", - "username": *username_clone, - "msg": json["msg"].as_str().unwrap() - }); - - recv_clone.websocket_broadcast_tx.send(WsMessage::Json( - serde_json::to_string(&send).expect("oh well"), - )); - - continue; - } - - "key" => { - if !json["keysym"].is_number() { - break; - } - - if !json["pressed"].is_number() { - break; - } - - let keysym = json["keysym"].as_u64().unwrap() as u32; - let pressed = json["pressed"].as_u64().unwrap() == 1; - - // FIXME: This would be MUCH better off being a set, so we don't - // hack-code set semantics here. Oh well. - { - let mut lock = recv_clone.inputs.lock().await; - if pressed { - if let None = lock.iter().position(|e| *e == keysym) { - lock.push(keysym); - } - } else { - if let Some(at) = lock.iter().position(|e| *e == keysym) { - lock.remove(at); - } - } - } - - /*let _ = recv_clone - .engine_tx - .send(vnc_engine::VncMessageInput::KeyEvent { - keysym: keysym, - pressed: pressed, - }) - .await;*/ - } - - "mouse" => { - if json["x"].as_u64().is_none() { - break; - } - - if json["y"].as_u64().is_none() { - break; - } - - if json["mask"].as_u64().is_none() { - break; - } - - //let x = json["x"].as_u64().unwrap() as u32; - //let y = json["y"].as_u64().unwrap() as u32; - //let mask = json["mask"].as_u64().unwrap() as u8; - - /*let _ = recv_clone - .engine_tx - .send(vnc_engine::VncMessageInput::MouseEvent { - pt: types::Point { x: x, y: y }, - buttons: mask, - }) - .await;*/ - } - _ => {} - } - } else { - break; - } - } - Message::Close(_) => break, - _ => {} - } - } - }); - - tokio::select! { - _ = (&mut send_task) => { - - recv_task.abort(); - }, - - _ = (&mut recv_task) => { - send_task.abort(); - } - } - - println!("{username} ({who}) left."); - - let dec_clone = Arc::clone(&state); - { - let mut lk = dec_clone.websocket_count.lock().await; - *lk -= 1; - } -} diff --git a/server/src/transport/mod.rs b/server/src/transport/mod.rs new file mode 100644 index 0000000..12e0027 --- /dev/null +++ b/server/src/transport/mod.rs @@ -0,0 +1,29 @@ +pub mod websocket; + +use std::sync::Arc; + +use async_trait::async_trait; + +#[derive(Clone, Debug)] +pub enum TransportMessage { + Text(String), + Binary(Vec) +} + +#[async_trait] +pub trait Transport : Send { + async fn listen(&self, iface: Arc) -> anyhow::Result<()>; + + /// Broadcasts a message. + fn broadcast_message(&self, m: TransportMessage) -> anyhow::Result<()>; +} + +#[async_trait] +pub trait TransportReciever : Send { + async fn on_connect(&self, username: &String) -> anyhow::Result<()>; + + async fn on_message(&self, username: &String, message: &String) -> anyhow::Result<()>; + + async fn on_leave(&self, username: &String) -> anyhow::Result<()>; + +} \ No newline at end of file diff --git a/server/src/transport/websocket.rs b/server/src/transport/websocket.rs new file mode 100644 index 0000000..f4461b5 --- /dev/null +++ b/server/src/transport/websocket.rs @@ -0,0 +1,154 @@ +use super::{Transport, TransportMessage, TransportReciever}; + +use async_trait::async_trait; + +use std::sync::Arc; + +use rand::distributions::DistString; +use std::net::SocketAddr; +use tokio::sync::{ + broadcast, + mpsc::{self, error::TryRecvError}, + Mutex as TokioMutex, +}; + +use axum::{ + extract::{ + connect_info::ConnectInfo, + ws::{self, Message, WebSocket, WebSocketUpgrade}, + State, + }, + response::IntoResponse, + routing::get, + Router, +}; + +use futures::{sink::SinkExt, stream::StreamExt}; + +pub struct WebsocketTransport { + broadcast_tx: Arc>, +} + +async fn ws_handler( + ws: WebSocketUpgrade, + ConnectInfo(addr): ConnectInfo, + + State(state): State>, + mut broadcast_rx: broadcast::Receiver, +) -> impl IntoResponse { + // finalize the upgrade process by returning upgrade callback. + // we can customize the callback by sending additional info such as address. + ws.on_upgrade(move |socket| handle_socket(socket, addr, state, broadcast_rx)) +} + +/// Actual websocket statemachine (one will be spawned per connection) +async fn handle_socket( + socket: WebSocket, + who: SocketAddr, + state: Arc, + mut broadcast_rx: broadcast::Receiver, +) { + let (mut sender, mut receiver) = socket.split(); + + // random username + let username: Arc = + Arc::new(rand::distributions::Alphanumeric.sample_string(&mut rand::thread_rng(), 16)); + + let recv_clone = Arc::clone(&state); + + state.on_connect(&username).await; + + let mut send_task = tokio::spawn(async move { + while let Ok(msg) = broadcast_rx.recv().await { + match msg { + TransportMessage::Text(b) => { + if sender.send(Message::Text(b)).await.is_err() { + break; + } + } + TransportMessage::Binary(b) => { + if sender.send(Message::Binary(b)).await.is_err() { + break; + } + } + } + } + }); + + let username_clone = username.clone(); + let mut recv_task = tokio::spawn(async move { + while let Some(Ok(msg)) = receiver.next().await { + match msg { + Message::Text(msg) => { + recv_clone.on_message(&username_clone, &msg).await; + } + Message::Close(_) => break, + _ => {} + } + } + }); + + tokio::select! { + _ = (&mut send_task) => { + recv_task.abort(); + }, + + _ = (&mut recv_task) => { + send_task.abort(); + } + } + + state.on_leave(&username).await; +} + +impl WebsocketTransport { + pub fn new() -> Self { + let (broadcast_tx, _) = broadcast::channel(32); + + Self { + broadcast_tx: Arc::new(broadcast_tx), + } + } +} + +#[async_trait] +impl Transport for WebsocketTransport { + async fn listen( + &self, + iface: Arc, + ) -> anyhow::Result<()> { + let tx_clone = self.broadcast_tx.clone(); + + // Axum websocket server + let app: Router<()> = + Router::new() + .route( + "/", + get( + |ws: WebSocketUpgrade, + info: ConnectInfo, + state: State>| async move { + ws_handler(ws, info, state, tx_clone.subscribe()).await + }, + ), + ) + .with_state(iface.clone()); + + let tcp_listener = tokio::net::TcpListener::bind("0.0.0.0:4940") + .await + .expect("failed to listen"); + + let axum_future = axum::serve( + tcp_listener, + app.into_make_service_with_connect_info::(), + ); + + axum_future.await?; + Ok(()) + } + + fn broadcast_message(&self, m: TransportMessage) -> anyhow::Result<()> { + let _ = self.broadcast_tx.send(m); + Ok(()) + } +}