refactor to use a set of traits for transport

(i want to try and make this use webtransport)
This commit is contained in:
Lily Tsuru 2024-10-11 05:09:12 -04:00
parent 8438ca11b5
commit 9248fe91a9
5 changed files with 329 additions and 230 deletions

5
server/Cargo.lock generated
View file

@ -34,9 +34,9 @@ checksum = "b3d1d046238990b9cf5bcde22a3fb3584ee5cf65fb2765f454ed428c7a0063da"
[[package]] [[package]]
name = "async-trait" name = "async-trait"
version = "0.1.82" version = "0.1.83"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a27b8a3a6e1a44fa4c8baf1f653e4172e81486d4941f2237e20dc2d0cf4ddff1" checksum = "721cae7de5c34fbb2acd27e21e6d2cf7b886dce0c27388d46c4e6c47ea4318dd"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
@ -1406,6 +1406,7 @@ name = "vncstream_server"
version = "0.1.0" version = "0.1.0"
dependencies = [ dependencies = [
"anyhow", "anyhow",
"async-trait",
"axum", "axum",
"cudarc", "cudarc",
"ffmpeg-next", "ffmpeg-next",

View file

@ -31,6 +31,7 @@ cudarc = { version = "0.12.1", features = [ "cuda-11050" ] }
tracing = "0.1.40" tracing = "0.1.40"
tracing-subscriber = "0.3.18" tracing-subscriber = "0.3.18"
xkeysym = "0.2.1" xkeysym = "0.2.1"
async-trait = "0.1.83"
[patch.crates-io] [patch.crates-io]

View file

@ -3,7 +3,12 @@ mod surface;
mod types; mod types;
mod video; mod video;
mod transport;
use async_trait::async_trait;
use retro_thread::{spawn_retro_thread, RetroEvent}; use retro_thread::{spawn_retro_thread, RetroEvent};
use transport::{Transport, TransportReciever};
use video::encoder_thread::EncodeThreadInput; use video::encoder_thread::EncodeThreadInput;
use video::{encoder_thread, ffmpeg}; use video::{encoder_thread, ffmpeg};
@ -43,22 +48,135 @@ struct AppState {
encoder_tx: Arc<TokioMutex<mpsc::Sender<EncodeThreadInput>>>, encoder_tx: Arc<TokioMutex<mpsc::Sender<EncodeThreadInput>>>,
inputs: Arc<TokioMutex<Vec<u32>>>, inputs: Arc<TokioMutex<Vec<u32>>>,
websocket_broadcast_tx: broadcast::Sender<WsMessage>, transport: Arc<crate::transport::websocket::WebsocketTransport>,
websocket_count: TokioMutex<usize>, connection_count: TokioMutex<usize>,
} }
impl AppState { impl AppState {
fn new(encoder_tx: mpsc::Sender<EncodeThreadInput>) -> Self { fn new(
let (broadcast_tx, _) = broadcast::channel(10); encoder_tx: mpsc::Sender<EncodeThreadInput>,
transport: Arc<crate::transport::websocket::WebsocketTransport>,
) -> Self {
Self { Self {
encoder_tx: Arc::new(TokioMutex::new(encoder_tx)), encoder_tx: Arc::new(TokioMutex::new(encoder_tx)),
inputs: Arc::new(TokioMutex::new(Vec::new())), inputs: Arc::new(TokioMutex::new(Vec::new())),
websocket_broadcast_tx: broadcast_tx, transport: transport,
websocket_count: TokioMutex::const_new(0usize), connection_count: TokioMutex::const_new(0usize),
} }
} }
} }
#[async_trait]
impl TransportReciever for AppState {
async fn on_connect(&self, username: &String) -> anyhow::Result<()> {
println!("{username} joined!");
{
let mut lk = self.connection_count.lock().await;
*lk += 1;
}
{
let locked = self.encoder_tx.lock().await;
// Force a ws connection to mean a keyframe
let _ = locked.send(EncodeThreadInput::ForceKeyframe).await;
let _ = locked.send(EncodeThreadInput::SendFrame).await;
}
Ok(())
}
async fn on_message(&self, username: &String, message: &String) -> anyhow::Result<()> {
if let Ok(json) = serde_json::from_str::<serde_json::Value>(&message) {
if !json["type"].is_string() {
return Ok(());
}
match json["type"].as_str().unwrap() {
"chat" => {
if !json["msg"].is_string() {
return Ok(());
}
let send = serde_json::json!({
"type": "chat",
"username": username,
"msg": json["msg"].as_str().unwrap()
});
self.transport
.broadcast_message(transport::TransportMessage::Text(
serde_json::to_string(&send).expect("oh well"),
))?;
}
"key" => {
if !json["keysym"].is_number() {
return Ok(());
}
if !json["pressed"].is_number() {
return Ok(());
}
let keysym = json["keysym"].as_u64().unwrap() as u32;
let pressed = json["pressed"].as_u64().unwrap() == 1;
// FIXME: This would be MUCH better off being a set, so we don't
// hack-code set semantics here. Oh well.
{
let mut lock = self.inputs.lock().await;
if pressed {
if let None = lock.iter().position(|e| *e == keysym) {
lock.push(keysym);
}
} else {
if let Some(at) = lock.iter().position(|e| *e == keysym) {
lock.remove(at);
}
}
}
}
"mouse" => {
if json["x"].as_u64().is_none() {
return Ok(());
}
if json["y"].as_u64().is_none() {
return Ok(());
}
if json["mask"].as_u64().is_none() {
return Ok(());
}
//let x = json["x"].as_u64().unwrap() as u32;
//let y = json["y"].as_u64().unwrap() as u32;
//let mask = json["mask"].as_u64().unwrap() as u8;
}
_ => {}
}
} else {
return Ok(());
}
Ok(())
}
async fn on_leave(&self, username: &String) -> anyhow::Result<()> {
{
let mut lk = self.connection_count.lock().await;
*lk -= 1;
}
println!("{username} left.");
Ok(())
}
}
#[tokio::main(flavor = "multi_thread", worker_threads = 2)] #[tokio::main(flavor = "multi_thread", worker_threads = 2)]
async fn main() -> anyhow::Result<()> { async fn main() -> anyhow::Result<()> {
// Setup a tracing subscriber // Setup a tracing subscriber
@ -74,12 +192,16 @@ async fn main() -> anyhow::Result<()> {
let frame: Arc<Mutex<Option<ffmpeg::frame::Video>>> = Arc::new(Mutex::new(None)); let frame: Arc<Mutex<Option<ffmpeg::frame::Video>>> = Arc::new(Mutex::new(None));
let (mut encoder_rx, encoder_tx) = encoder_thread::encoder_thread_spawn(&frame); let (mut encoder_rx, encoder_tx) = encoder_thread::encoder_thread_spawn(&frame);
let state = Arc::new(AppState::new(encoder_tx)); let transport = Arc::new(crate::transport::websocket::WebsocketTransport::new());
let state = Arc::new(AppState::new(encoder_tx, transport.clone()));
let (mut event_rx, event_in_tx) = spawn_retro_thread(surface.clone()); let (mut event_rx, event_in_tx) = spawn_retro_thread(surface.clone());
let state_clone = state.clone(); let state_clone = state.clone();
let transport_clone = transport.clone();
// retro event handler. drives the encoder thread too // retro event handler. drives the encoder thread too
let _ = std::thread::Builder::new() let _ = std::thread::Builder::new()
.name("retro_event_rx".into()) .name("retro_event_rx".into())
@ -168,9 +290,18 @@ async fn main() -> anyhow::Result<()> {
match encoder_rx.try_recv() { match encoder_rx.try_recv() {
Ok(msg) => match msg { Ok(msg) => match msg {
encoder_thread::EncodeThreadOutput::Frame { packet } => { encoder_thread::EncodeThreadOutput::Frame { packet } => {
let _ = state_clone // let _ = state_clone
.websocket_broadcast_tx // .websocket_broadcast_tx
.send(WsMessage::VideoPacket { packet }); // .send(WsMessage::VideoPacket { packet });
// :(
let packet_data = {
let slice = packet.data().expect(
"should NOT be empty, this invariant is checked beforehand",
);
slice.to_vec()
};
let _ = transport_clone.broadcast_message(transport::TransportMessage::Binary(packet_data));
} }
}, },
Err(TryRecvError::Empty) => {} Err(TryRecvError::Empty) => {}
@ -182,223 +313,6 @@ async fn main() -> anyhow::Result<()> {
}) })
.expect("failed to spawn retro RX thread, it's probably over"); .expect("failed to spawn retro RX thread, it's probably over");
// Axum websocket server transport.listen(state).await?;
let app: Router<()> = Router::new()
.route(
"/",
get(
|ws: WebSocketUpgrade,
info: ConnectInfo<SocketAddr>,
state: State<Arc<AppState>>| async move {
ws_handler(ws, info, state).await
},
),
)
.with_state(state.clone());
let tcp_listener = tokio::net::TcpListener::bind("0.0.0.0:4940")
.await
.expect("failed to listen");
let axum_future = axum::serve(
tcp_listener,
app.into_make_service_with_connect_info::<SocketAddr>(),
);
// If the VNC client disconnects we should exit.
tokio::select! {
_ = axum_future => {
println!("axum died");
}
}
Ok(()) Ok(())
} }
async fn ws_handler(
ws: WebSocketUpgrade,
ConnectInfo(addr): ConnectInfo<SocketAddr>,
State(state): State<Arc<AppState>>,
) -> impl IntoResponse {
// finalize the upgrade process by returning upgrade callback.
// we can customize the callback by sending additional info such as address.
ws.on_upgrade(move |socket| handle_socket(socket, addr, state))
}
/// Actual websocket statemachine (one will be spawned per connection)
async fn handle_socket(socket: WebSocket, who: SocketAddr, state: Arc<AppState>) {
let (mut sender, mut receiver) = socket.split();
// increment connection count
let inc_clone = Arc::clone(&state);
{
let mut lk = inc_clone.websocket_count.lock().await;
*lk += 1;
}
{
let locked = state.encoder_tx.lock().await;
// Force a ws connection to mean a keyframe
let _ = locked.send(EncodeThreadInput::ForceKeyframe).await;
let _ = locked.send(EncodeThreadInput::SendFrame).await;
}
// random username
let username: Arc<String> =
Arc::new(rand::distributions::Alphanumeric.sample_string(&mut rand::thread_rng(), 16));
println!("{username} ({who}) connected.");
let send_clone = Arc::clone(&state);
let mut send_task = tokio::spawn(async move {
let mut sub = send_clone.websocket_broadcast_tx.subscribe();
while let Ok(msg) = sub.recv().await {
match msg {
WsMessage::VideoPacket { mut packet } => {
// :(. At least this copy doesn't occur on the driver threads anymore..
let data = packet.data_mut().expect("shouldn't be taken");
let msg = ws::Message::Binary(data.to_vec());
if sender.send(msg).await.is_err() {
break;
}
}
WsMessage::Json(s) => {
let msg = ws::Message::Text(s);
if sender.send(msg).await.is_err() {
break;
}
}
}
}
});
let username_clone = Arc::clone(&username);
let recv_clone = Arc::clone(&state);
let mut recv_task = tokio::spawn(async move {
while let Some(Ok(msg)) = receiver.next().await {
match msg {
Message::Text(msg) => {
// println!("{}", msg);
if let Ok(json) = serde_json::from_str::<serde_json::Value>(&msg) {
if !json["type"].is_string() {
break;
}
match json["type"].as_str().unwrap() {
"chat" => {
if !json["msg"].is_string() {
break;
}
let send = serde_json::json!({
"type": "chat",
"username": *username_clone,
"msg": json["msg"].as_str().unwrap()
});
recv_clone.websocket_broadcast_tx.send(WsMessage::Json(
serde_json::to_string(&send).expect("oh well"),
));
continue;
}
"key" => {
if !json["keysym"].is_number() {
break;
}
if !json["pressed"].is_number() {
break;
}
let keysym = json["keysym"].as_u64().unwrap() as u32;
let pressed = json["pressed"].as_u64().unwrap() == 1;
// FIXME: This would be MUCH better off being a set, so we don't
// hack-code set semantics here. Oh well.
{
let mut lock = recv_clone.inputs.lock().await;
if pressed {
if let None = lock.iter().position(|e| *e == keysym) {
lock.push(keysym);
}
} else {
if let Some(at) = lock.iter().position(|e| *e == keysym) {
lock.remove(at);
}
}
}
/*let _ = recv_clone
.engine_tx
.send(vnc_engine::VncMessageInput::KeyEvent {
keysym: keysym,
pressed: pressed,
})
.await;*/
}
"mouse" => {
if json["x"].as_u64().is_none() {
break;
}
if json["y"].as_u64().is_none() {
break;
}
if json["mask"].as_u64().is_none() {
break;
}
//let x = json["x"].as_u64().unwrap() as u32;
//let y = json["y"].as_u64().unwrap() as u32;
//let mask = json["mask"].as_u64().unwrap() as u8;
/*let _ = recv_clone
.engine_tx
.send(vnc_engine::VncMessageInput::MouseEvent {
pt: types::Point { x: x, y: y },
buttons: mask,
})
.await;*/
}
_ => {}
}
} else {
break;
}
}
Message::Close(_) => break,
_ => {}
}
}
});
tokio::select! {
_ = (&mut send_task) => {
recv_task.abort();
},
_ = (&mut recv_task) => {
send_task.abort();
}
}
println!("{username} ({who}) left.");
let dec_clone = Arc::clone(&state);
{
let mut lk = dec_clone.websocket_count.lock().await;
*lk -= 1;
}
}

View file

@ -0,0 +1,29 @@
pub mod websocket;
use std::sync::Arc;
use async_trait::async_trait;
#[derive(Clone, Debug)]
pub enum TransportMessage {
Text(String),
Binary(Vec<u8>)
}
#[async_trait]
pub trait Transport : Send {
async fn listen<T: TransportReciever + Send + Sync + 'static>(&self, iface: Arc<T>) -> anyhow::Result<()>;
/// Broadcasts a message.
fn broadcast_message(&self, m: TransportMessage) -> anyhow::Result<()>;
}
#[async_trait]
pub trait TransportReciever : Send {
async fn on_connect(&self, username: &String) -> anyhow::Result<()>;
async fn on_message(&self, username: &String, message: &String) -> anyhow::Result<()>;
async fn on_leave(&self, username: &String) -> anyhow::Result<()>;
}

View file

@ -0,0 +1,154 @@
use super::{Transport, TransportMessage, TransportReciever};
use async_trait::async_trait;
use std::sync::Arc;
use rand::distributions::DistString;
use std::net::SocketAddr;
use tokio::sync::{
broadcast,
mpsc::{self, error::TryRecvError},
Mutex as TokioMutex,
};
use axum::{
extract::{
connect_info::ConnectInfo,
ws::{self, Message, WebSocket, WebSocketUpgrade},
State,
},
response::IntoResponse,
routing::get,
Router,
};
use futures::{sink::SinkExt, stream::StreamExt};
pub struct WebsocketTransport {
broadcast_tx: Arc<broadcast::Sender<TransportMessage>>,
}
async fn ws_handler<T: TransportReciever + Sync + Send + 'static>(
ws: WebSocketUpgrade,
ConnectInfo(addr): ConnectInfo<SocketAddr>,
State(state): State<Arc<T>>,
mut broadcast_rx: broadcast::Receiver<TransportMessage>,
) -> impl IntoResponse {
// finalize the upgrade process by returning upgrade callback.
// we can customize the callback by sending additional info such as address.
ws.on_upgrade(move |socket| handle_socket(socket, addr, state, broadcast_rx))
}
/// Actual websocket statemachine (one will be spawned per connection)
async fn handle_socket<T: TransportReciever + Sync + Send + 'static>(
socket: WebSocket,
who: SocketAddr,
state: Arc<T>,
mut broadcast_rx: broadcast::Receiver<TransportMessage>,
) {
let (mut sender, mut receiver) = socket.split();
// random username
let username: Arc<String> =
Arc::new(rand::distributions::Alphanumeric.sample_string(&mut rand::thread_rng(), 16));
let recv_clone = Arc::clone(&state);
state.on_connect(&username).await;
let mut send_task = tokio::spawn(async move {
while let Ok(msg) = broadcast_rx.recv().await {
match msg {
TransportMessage::Text(b) => {
if sender.send(Message::Text(b)).await.is_err() {
break;
}
}
TransportMessage::Binary(b) => {
if sender.send(Message::Binary(b)).await.is_err() {
break;
}
}
}
}
});
let username_clone = username.clone();
let mut recv_task = tokio::spawn(async move {
while let Some(Ok(msg)) = receiver.next().await {
match msg {
Message::Text(msg) => {
recv_clone.on_message(&username_clone, &msg).await;
}
Message::Close(_) => break,
_ => {}
}
}
});
tokio::select! {
_ = (&mut send_task) => {
recv_task.abort();
},
_ = (&mut recv_task) => {
send_task.abort();
}
}
state.on_leave(&username).await;
}
impl WebsocketTransport {
pub fn new() -> Self {
let (broadcast_tx, _) = broadcast::channel(32);
Self {
broadcast_tx: Arc::new(broadcast_tx),
}
}
}
#[async_trait]
impl Transport for WebsocketTransport {
async fn listen<T: super::TransportReciever + Send + Sync + 'static>(
&self,
iface: Arc<T>,
) -> anyhow::Result<()> {
let tx_clone = self.broadcast_tx.clone();
// Axum websocket server
let app: Router<()> =
Router::new()
.route(
"/",
get(
|ws: WebSocketUpgrade,
info: ConnectInfo<SocketAddr>,
state: State<Arc<T>>| async move {
ws_handler(ws, info, state, tx_clone.subscribe()).await
},
),
)
.with_state(iface.clone());
let tcp_listener = tokio::net::TcpListener::bind("0.0.0.0:4940")
.await
.expect("failed to listen");
let axum_future = axum::serve(
tcp_listener,
app.into_make_service_with_connect_info::<SocketAddr>(),
);
axum_future.await?;
Ok(())
}
fn broadcast_message(&self, m: TransportMessage) -> anyhow::Result<()> {
let _ = self.broadcast_tx.send(m);
Ok(())
}
}