refactor to use a set of traits for transport
(i want to try and make this use webtransport)
This commit is contained in:
parent
8438ca11b5
commit
9248fe91a9
5 changed files with 329 additions and 230 deletions
5
server/Cargo.lock
generated
5
server/Cargo.lock
generated
|
@ -34,9 +34,9 @@ checksum = "b3d1d046238990b9cf5bcde22a3fb3584ee5cf65fb2765f454ed428c7a0063da"
|
|||
|
||||
[[package]]
|
||||
name = "async-trait"
|
||||
version = "0.1.82"
|
||||
version = "0.1.83"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a27b8a3a6e1a44fa4c8baf1f653e4172e81486d4941f2237e20dc2d0cf4ddff1"
|
||||
checksum = "721cae7de5c34fbb2acd27e21e6d2cf7b886dce0c27388d46c4e6c47ea4318dd"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
|
@ -1406,6 +1406,7 @@ name = "vncstream_server"
|
|||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"async-trait",
|
||||
"axum",
|
||||
"cudarc",
|
||||
"ffmpeg-next",
|
||||
|
|
|
@ -31,6 +31,7 @@ cudarc = { version = "0.12.1", features = [ "cuda-11050" ] }
|
|||
tracing = "0.1.40"
|
||||
tracing-subscriber = "0.3.18"
|
||||
xkeysym = "0.2.1"
|
||||
async-trait = "0.1.83"
|
||||
|
||||
|
||||
[patch.crates-io]
|
||||
|
|
|
@ -3,7 +3,12 @@ mod surface;
|
|||
mod types;
|
||||
mod video;
|
||||
|
||||
mod transport;
|
||||
|
||||
use async_trait::async_trait;
|
||||
|
||||
use retro_thread::{spawn_retro_thread, RetroEvent};
|
||||
use transport::{Transport, TransportReciever};
|
||||
use video::encoder_thread::EncodeThreadInput;
|
||||
use video::{encoder_thread, ffmpeg};
|
||||
|
||||
|
@ -43,22 +48,135 @@ struct AppState {
|
|||
encoder_tx: Arc<TokioMutex<mpsc::Sender<EncodeThreadInput>>>,
|
||||
inputs: Arc<TokioMutex<Vec<u32>>>,
|
||||
|
||||
websocket_broadcast_tx: broadcast::Sender<WsMessage>,
|
||||
websocket_count: TokioMutex<usize>,
|
||||
transport: Arc<crate::transport::websocket::WebsocketTransport>,
|
||||
connection_count: TokioMutex<usize>,
|
||||
}
|
||||
|
||||
impl AppState {
|
||||
fn new(encoder_tx: mpsc::Sender<EncodeThreadInput>) -> Self {
|
||||
let (broadcast_tx, _) = broadcast::channel(10);
|
||||
fn new(
|
||||
encoder_tx: mpsc::Sender<EncodeThreadInput>,
|
||||
transport: Arc<crate::transport::websocket::WebsocketTransport>,
|
||||
) -> Self {
|
||||
Self {
|
||||
encoder_tx: Arc::new(TokioMutex::new(encoder_tx)),
|
||||
inputs: Arc::new(TokioMutex::new(Vec::new())),
|
||||
websocket_broadcast_tx: broadcast_tx,
|
||||
websocket_count: TokioMutex::const_new(0usize),
|
||||
transport: transport,
|
||||
connection_count: TokioMutex::const_new(0usize),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl TransportReciever for AppState {
|
||||
async fn on_connect(&self, username: &String) -> anyhow::Result<()> {
|
||||
println!("{username} joined!");
|
||||
|
||||
{
|
||||
let mut lk = self.connection_count.lock().await;
|
||||
*lk += 1;
|
||||
}
|
||||
|
||||
{
|
||||
let locked = self.encoder_tx.lock().await;
|
||||
|
||||
// Force a ws connection to mean a keyframe
|
||||
let _ = locked.send(EncodeThreadInput::ForceKeyframe).await;
|
||||
let _ = locked.send(EncodeThreadInput::SendFrame).await;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn on_message(&self, username: &String, message: &String) -> anyhow::Result<()> {
|
||||
if let Ok(json) = serde_json::from_str::<serde_json::Value>(&message) {
|
||||
if !json["type"].is_string() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
match json["type"].as_str().unwrap() {
|
||||
"chat" => {
|
||||
if !json["msg"].is_string() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let send = serde_json::json!({
|
||||
"type": "chat",
|
||||
"username": username,
|
||||
"msg": json["msg"].as_str().unwrap()
|
||||
});
|
||||
|
||||
|
||||
self.transport
|
||||
.broadcast_message(transport::TransportMessage::Text(
|
||||
serde_json::to_string(&send).expect("oh well"),
|
||||
))?;
|
||||
}
|
||||
|
||||
"key" => {
|
||||
if !json["keysym"].is_number() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
if !json["pressed"].is_number() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let keysym = json["keysym"].as_u64().unwrap() as u32;
|
||||
let pressed = json["pressed"].as_u64().unwrap() == 1;
|
||||
|
||||
// FIXME: This would be MUCH better off being a set, so we don't
|
||||
// hack-code set semantics here. Oh well.
|
||||
{
|
||||
let mut lock = self.inputs.lock().await;
|
||||
if pressed {
|
||||
if let None = lock.iter().position(|e| *e == keysym) {
|
||||
lock.push(keysym);
|
||||
}
|
||||
} else {
|
||||
if let Some(at) = lock.iter().position(|e| *e == keysym) {
|
||||
lock.remove(at);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
"mouse" => {
|
||||
if json["x"].as_u64().is_none() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
if json["y"].as_u64().is_none() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
if json["mask"].as_u64().is_none() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
//let x = json["x"].as_u64().unwrap() as u32;
|
||||
//let y = json["y"].as_u64().unwrap() as u32;
|
||||
//let mask = json["mask"].as_u64().unwrap() as u8;
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
} else {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn on_leave(&self, username: &String) -> anyhow::Result<()> {
|
||||
{
|
||||
let mut lk = self.connection_count.lock().await;
|
||||
*lk -= 1;
|
||||
}
|
||||
|
||||
println!("{username} left.");
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::main(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn main() -> anyhow::Result<()> {
|
||||
// Setup a tracing subscriber
|
||||
|
@ -74,12 +192,16 @@ async fn main() -> anyhow::Result<()> {
|
|||
let frame: Arc<Mutex<Option<ffmpeg::frame::Video>>> = Arc::new(Mutex::new(None));
|
||||
let (mut encoder_rx, encoder_tx) = encoder_thread::encoder_thread_spawn(&frame);
|
||||
|
||||
let state = Arc::new(AppState::new(encoder_tx));
|
||||
let transport = Arc::new(crate::transport::websocket::WebsocketTransport::new());
|
||||
|
||||
let state = Arc::new(AppState::new(encoder_tx, transport.clone()));
|
||||
|
||||
let (mut event_rx, event_in_tx) = spawn_retro_thread(surface.clone());
|
||||
|
||||
let state_clone = state.clone();
|
||||
|
||||
let transport_clone = transport.clone();
|
||||
|
||||
// retro event handler. drives the encoder thread too
|
||||
let _ = std::thread::Builder::new()
|
||||
.name("retro_event_rx".into())
|
||||
|
@ -168,9 +290,18 @@ async fn main() -> anyhow::Result<()> {
|
|||
match encoder_rx.try_recv() {
|
||||
Ok(msg) => match msg {
|
||||
encoder_thread::EncodeThreadOutput::Frame { packet } => {
|
||||
let _ = state_clone
|
||||
.websocket_broadcast_tx
|
||||
.send(WsMessage::VideoPacket { packet });
|
||||
// let _ = state_clone
|
||||
// .websocket_broadcast_tx
|
||||
// .send(WsMessage::VideoPacket { packet });
|
||||
|
||||
// :(
|
||||
let packet_data = {
|
||||
let slice = packet.data().expect(
|
||||
"should NOT be empty, this invariant is checked beforehand",
|
||||
);
|
||||
slice.to_vec()
|
||||
};
|
||||
let _ = transport_clone.broadcast_message(transport::TransportMessage::Binary(packet_data));
|
||||
}
|
||||
},
|
||||
Err(TryRecvError::Empty) => {}
|
||||
|
@ -182,223 +313,6 @@ async fn main() -> anyhow::Result<()> {
|
|||
})
|
||||
.expect("failed to spawn retro RX thread, it's probably over");
|
||||
|
||||
// Axum websocket server
|
||||
let app: Router<()> = Router::new()
|
||||
.route(
|
||||
"/",
|
||||
get(
|
||||
|ws: WebSocketUpgrade,
|
||||
info: ConnectInfo<SocketAddr>,
|
||||
state: State<Arc<AppState>>| async move {
|
||||
ws_handler(ws, info, state).await
|
||||
},
|
||||
),
|
||||
)
|
||||
.with_state(state.clone());
|
||||
|
||||
let tcp_listener = tokio::net::TcpListener::bind("0.0.0.0:4940")
|
||||
.await
|
||||
.expect("failed to listen");
|
||||
|
||||
let axum_future = axum::serve(
|
||||
tcp_listener,
|
||||
app.into_make_service_with_connect_info::<SocketAddr>(),
|
||||
);
|
||||
|
||||
// If the VNC client disconnects we should exit.
|
||||
tokio::select! {
|
||||
_ = axum_future => {
|
||||
println!("axum died");
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
transport.listen(state).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn ws_handler(
|
||||
ws: WebSocketUpgrade,
|
||||
ConnectInfo(addr): ConnectInfo<SocketAddr>,
|
||||
|
||||
State(state): State<Arc<AppState>>,
|
||||
) -> impl IntoResponse {
|
||||
// finalize the upgrade process by returning upgrade callback.
|
||||
// we can customize the callback by sending additional info such as address.
|
||||
ws.on_upgrade(move |socket| handle_socket(socket, addr, state))
|
||||
}
|
||||
|
||||
/// Actual websocket statemachine (one will be spawned per connection)
|
||||
async fn handle_socket(socket: WebSocket, who: SocketAddr, state: Arc<AppState>) {
|
||||
let (mut sender, mut receiver) = socket.split();
|
||||
|
||||
// increment connection count
|
||||
let inc_clone = Arc::clone(&state);
|
||||
{
|
||||
let mut lk = inc_clone.websocket_count.lock().await;
|
||||
*lk += 1;
|
||||
}
|
||||
|
||||
{
|
||||
let locked = state.encoder_tx.lock().await;
|
||||
|
||||
// Force a ws connection to mean a keyframe
|
||||
let _ = locked.send(EncodeThreadInput::ForceKeyframe).await;
|
||||
let _ = locked.send(EncodeThreadInput::SendFrame).await;
|
||||
}
|
||||
|
||||
// random username
|
||||
let username: Arc<String> =
|
||||
Arc::new(rand::distributions::Alphanumeric.sample_string(&mut rand::thread_rng(), 16));
|
||||
|
||||
println!("{username} ({who}) connected.");
|
||||
|
||||
let send_clone = Arc::clone(&state);
|
||||
let mut send_task = tokio::spawn(async move {
|
||||
let mut sub = send_clone.websocket_broadcast_tx.subscribe();
|
||||
|
||||
while let Ok(msg) = sub.recv().await {
|
||||
match msg {
|
||||
WsMessage::VideoPacket { mut packet } => {
|
||||
// :(. At least this copy doesn't occur on the driver threads anymore..
|
||||
let data = packet.data_mut().expect("shouldn't be taken");
|
||||
let msg = ws::Message::Binary(data.to_vec());
|
||||
if sender.send(msg).await.is_err() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
WsMessage::Json(s) => {
|
||||
let msg = ws::Message::Text(s);
|
||||
if sender.send(msg).await.is_err() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
let username_clone = Arc::clone(&username);
|
||||
|
||||
let recv_clone = Arc::clone(&state);
|
||||
|
||||
let mut recv_task = tokio::spawn(async move {
|
||||
while let Some(Ok(msg)) = receiver.next().await {
|
||||
match msg {
|
||||
Message::Text(msg) => {
|
||||
// println!("{}", msg);
|
||||
if let Ok(json) = serde_json::from_str::<serde_json::Value>(&msg) {
|
||||
if !json["type"].is_string() {
|
||||
break;
|
||||
}
|
||||
|
||||
match json["type"].as_str().unwrap() {
|
||||
"chat" => {
|
||||
if !json["msg"].is_string() {
|
||||
break;
|
||||
}
|
||||
|
||||
let send = serde_json::json!({
|
||||
"type": "chat",
|
||||
"username": *username_clone,
|
||||
"msg": json["msg"].as_str().unwrap()
|
||||
});
|
||||
|
||||
recv_clone.websocket_broadcast_tx.send(WsMessage::Json(
|
||||
serde_json::to_string(&send).expect("oh well"),
|
||||
));
|
||||
|
||||
continue;
|
||||
}
|
||||
|
||||
"key" => {
|
||||
if !json["keysym"].is_number() {
|
||||
break;
|
||||
}
|
||||
|
||||
if !json["pressed"].is_number() {
|
||||
break;
|
||||
}
|
||||
|
||||
let keysym = json["keysym"].as_u64().unwrap() as u32;
|
||||
let pressed = json["pressed"].as_u64().unwrap() == 1;
|
||||
|
||||
// FIXME: This would be MUCH better off being a set, so we don't
|
||||
// hack-code set semantics here. Oh well.
|
||||
{
|
||||
let mut lock = recv_clone.inputs.lock().await;
|
||||
if pressed {
|
||||
if let None = lock.iter().position(|e| *e == keysym) {
|
||||
lock.push(keysym);
|
||||
}
|
||||
} else {
|
||||
if let Some(at) = lock.iter().position(|e| *e == keysym) {
|
||||
lock.remove(at);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/*let _ = recv_clone
|
||||
.engine_tx
|
||||
.send(vnc_engine::VncMessageInput::KeyEvent {
|
||||
keysym: keysym,
|
||||
pressed: pressed,
|
||||
})
|
||||
.await;*/
|
||||
}
|
||||
|
||||
"mouse" => {
|
||||
if json["x"].as_u64().is_none() {
|
||||
break;
|
||||
}
|
||||
|
||||
if json["y"].as_u64().is_none() {
|
||||
break;
|
||||
}
|
||||
|
||||
if json["mask"].as_u64().is_none() {
|
||||
break;
|
||||
}
|
||||
|
||||
//let x = json["x"].as_u64().unwrap() as u32;
|
||||
//let y = json["y"].as_u64().unwrap() as u32;
|
||||
//let mask = json["mask"].as_u64().unwrap() as u8;
|
||||
|
||||
/*let _ = recv_clone
|
||||
.engine_tx
|
||||
.send(vnc_engine::VncMessageInput::MouseEvent {
|
||||
pt: types::Point { x: x, y: y },
|
||||
buttons: mask,
|
||||
})
|
||||
.await;*/
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
Message::Close(_) => break,
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
tokio::select! {
|
||||
_ = (&mut send_task) => {
|
||||
|
||||
recv_task.abort();
|
||||
},
|
||||
|
||||
_ = (&mut recv_task) => {
|
||||
send_task.abort();
|
||||
}
|
||||
}
|
||||
|
||||
println!("{username} ({who}) left.");
|
||||
|
||||
let dec_clone = Arc::clone(&state);
|
||||
{
|
||||
let mut lk = dec_clone.websocket_count.lock().await;
|
||||
*lk -= 1;
|
||||
}
|
||||
}
|
||||
|
|
29
server/src/transport/mod.rs
Normal file
29
server/src/transport/mod.rs
Normal file
|
@ -0,0 +1,29 @@
|
|||
pub mod websocket;
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use async_trait::async_trait;
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub enum TransportMessage {
|
||||
Text(String),
|
||||
Binary(Vec<u8>)
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
pub trait Transport : Send {
|
||||
async fn listen<T: TransportReciever + Send + Sync + 'static>(&self, iface: Arc<T>) -> anyhow::Result<()>;
|
||||
|
||||
/// Broadcasts a message.
|
||||
fn broadcast_message(&self, m: TransportMessage) -> anyhow::Result<()>;
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
pub trait TransportReciever : Send {
|
||||
async fn on_connect(&self, username: &String) -> anyhow::Result<()>;
|
||||
|
||||
async fn on_message(&self, username: &String, message: &String) -> anyhow::Result<()>;
|
||||
|
||||
async fn on_leave(&self, username: &String) -> anyhow::Result<()>;
|
||||
|
||||
}
|
154
server/src/transport/websocket.rs
Normal file
154
server/src/transport/websocket.rs
Normal file
|
@ -0,0 +1,154 @@
|
|||
use super::{Transport, TransportMessage, TransportReciever};
|
||||
|
||||
use async_trait::async_trait;
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use rand::distributions::DistString;
|
||||
use std::net::SocketAddr;
|
||||
use tokio::sync::{
|
||||
broadcast,
|
||||
mpsc::{self, error::TryRecvError},
|
||||
Mutex as TokioMutex,
|
||||
};
|
||||
|
||||
use axum::{
|
||||
extract::{
|
||||
connect_info::ConnectInfo,
|
||||
ws::{self, Message, WebSocket, WebSocketUpgrade},
|
||||
State,
|
||||
},
|
||||
response::IntoResponse,
|
||||
routing::get,
|
||||
Router,
|
||||
};
|
||||
|
||||
use futures::{sink::SinkExt, stream::StreamExt};
|
||||
|
||||
pub struct WebsocketTransport {
|
||||
broadcast_tx: Arc<broadcast::Sender<TransportMessage>>,
|
||||
}
|
||||
|
||||
async fn ws_handler<T: TransportReciever + Sync + Send + 'static>(
|
||||
ws: WebSocketUpgrade,
|
||||
ConnectInfo(addr): ConnectInfo<SocketAddr>,
|
||||
|
||||
State(state): State<Arc<T>>,
|
||||
mut broadcast_rx: broadcast::Receiver<TransportMessage>,
|
||||
) -> impl IntoResponse {
|
||||
// finalize the upgrade process by returning upgrade callback.
|
||||
// we can customize the callback by sending additional info such as address.
|
||||
ws.on_upgrade(move |socket| handle_socket(socket, addr, state, broadcast_rx))
|
||||
}
|
||||
|
||||
/// Actual websocket statemachine (one will be spawned per connection)
|
||||
async fn handle_socket<T: TransportReciever + Sync + Send + 'static>(
|
||||
socket: WebSocket,
|
||||
who: SocketAddr,
|
||||
state: Arc<T>,
|
||||
mut broadcast_rx: broadcast::Receiver<TransportMessage>,
|
||||
) {
|
||||
let (mut sender, mut receiver) = socket.split();
|
||||
|
||||
// random username
|
||||
let username: Arc<String> =
|
||||
Arc::new(rand::distributions::Alphanumeric.sample_string(&mut rand::thread_rng(), 16));
|
||||
|
||||
let recv_clone = Arc::clone(&state);
|
||||
|
||||
state.on_connect(&username).await;
|
||||
|
||||
let mut send_task = tokio::spawn(async move {
|
||||
while let Ok(msg) = broadcast_rx.recv().await {
|
||||
match msg {
|
||||
TransportMessage::Text(b) => {
|
||||
if sender.send(Message::Text(b)).await.is_err() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
TransportMessage::Binary(b) => {
|
||||
if sender.send(Message::Binary(b)).await.is_err() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
let username_clone = username.clone();
|
||||
let mut recv_task = tokio::spawn(async move {
|
||||
while let Some(Ok(msg)) = receiver.next().await {
|
||||
match msg {
|
||||
Message::Text(msg) => {
|
||||
recv_clone.on_message(&username_clone, &msg).await;
|
||||
}
|
||||
Message::Close(_) => break,
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
tokio::select! {
|
||||
_ = (&mut send_task) => {
|
||||
recv_task.abort();
|
||||
},
|
||||
|
||||
_ = (&mut recv_task) => {
|
||||
send_task.abort();
|
||||
}
|
||||
}
|
||||
|
||||
state.on_leave(&username).await;
|
||||
}
|
||||
|
||||
impl WebsocketTransport {
|
||||
pub fn new() -> Self {
|
||||
let (broadcast_tx, _) = broadcast::channel(32);
|
||||
|
||||
Self {
|
||||
broadcast_tx: Arc::new(broadcast_tx),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Transport for WebsocketTransport {
|
||||
async fn listen<T: super::TransportReciever + Send + Sync + 'static>(
|
||||
&self,
|
||||
iface: Arc<T>,
|
||||
) -> anyhow::Result<()> {
|
||||
let tx_clone = self.broadcast_tx.clone();
|
||||
|
||||
// Axum websocket server
|
||||
let app: Router<()> =
|
||||
Router::new()
|
||||
.route(
|
||||
"/",
|
||||
get(
|
||||
|ws: WebSocketUpgrade,
|
||||
info: ConnectInfo<SocketAddr>,
|
||||
state: State<Arc<T>>| async move {
|
||||
ws_handler(ws, info, state, tx_clone.subscribe()).await
|
||||
},
|
||||
),
|
||||
)
|
||||
.with_state(iface.clone());
|
||||
|
||||
let tcp_listener = tokio::net::TcpListener::bind("0.0.0.0:4940")
|
||||
.await
|
||||
.expect("failed to listen");
|
||||
|
||||
let axum_future = axum::serve(
|
||||
tcp_listener,
|
||||
app.into_make_service_with_connect_info::<SocketAddr>(),
|
||||
);
|
||||
|
||||
axum_future.await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn broadcast_message(&self, m: TransportMessage) -> anyhow::Result<()> {
|
||||
let _ = self.broadcast_tx.send(m);
|
||||
Ok(())
|
||||
}
|
||||
}
|
Loading…
Reference in a new issue