diff --git a/.changeset/initial_support_for_data_tracks.md b/.changeset/initial_support_for_data_tracks.md new file mode 100644 index 000000000..5ef5593f7 --- /dev/null +++ b/.changeset/initial_support_for_data_tracks.md @@ -0,0 +1,5 @@ +--- +livekit: minor +--- + +# Initial support for data tracks diff --git a/.changeset/rename_to_data_track_stream.md b/.changeset/rename_to_data_track_stream.md new file mode 100644 index 000000000..807bb9b21 --- /dev/null +++ b/.changeset/rename_to_data_track_stream.md @@ -0,0 +1,5 @@ +--- +livekit-datatrack: patch +--- + +# Rename type to DataTrackStream diff --git a/Cargo.lock b/Cargo.lock index bb235cdbb..2c1f9c93b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -679,6 +679,18 @@ dependencies = [ "serde", ] +[[package]] +name = "basic_data_track" +version = "0.1.0" +dependencies = [ + "anyhow", + "env_logger 0.11.10", + "futures-util", + "livekit", + "log", + "tokio", +] + [[package]] name = "basic_room" version = "0.1.0" @@ -3985,6 +3997,7 @@ dependencies = [ "libloading 0.8.9", "libwebrtc", "livekit-api", + "livekit-datatrack", "livekit-protocol", "livekit-runtime", "log", @@ -3993,8 +4006,9 @@ dependencies = [ "semver", "serde", "serde_json", + "test-case", "test-log", - "thiserror 2.0.18", + "thiserror 1.0.69", "tokio", ] diff --git a/Cargo.toml b/Cargo.toml index 1a4d9a1c5..984f1af28 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -19,6 +19,7 @@ members = [ "examples/agent_dispatch", "examples/api", + "examples/basic_data_track", "examples/basic_room", "examples/basic_text_stream", "examples/encrypted_text_stream", diff --git a/examples/basic_data_track/Cargo.toml b/examples/basic_data_track/Cargo.toml new file mode 100644 index 000000000..3ad66f235 --- /dev/null +++ b/examples/basic_data_track/Cargo.toml @@ -0,0 +1,21 @@ +[package] +name = "basic_data_track" +version = "0.1.0" +edition = "2021" +publish = false + +[dependencies] +tokio = { workspace = true, features = ["full"] } +livekit = { workspace = true, features = ["rustls-tls-native-roots"] } +futures-util = { workspace = true, default-features = false, features = ["sink"] } +anyhow = { workspace = true } +log = { workspace = true } +env_logger = { workspace = true } + +[[bin]] +name = "publisher" +path = "src/publisher.rs" + +[[bin]] +name = "subscriber" +path = "src/subscriber.rs" \ No newline at end of file diff --git a/examples/basic_data_track/README.md b/examples/basic_data_track/README.md new file mode 100644 index 000000000..537d029ea --- /dev/null +++ b/examples/basic_data_track/README.md @@ -0,0 +1,21 @@ +# Basic Data Track + +Simple example of publishing and subscribing to a data track. + +## Usage + +1. Run the publisher: + +```sh +export LIVEKIT_URL="..." +export LIVEKIT_TOKEN="" +cargo run --bin publisher +``` + +2. In a second terminal, run the subscriber: + +```sh +export LIVEKIT_URL="..." +export LIVEKIT_TOKEN="" +cargo run --bin subscriber +``` diff --git a/examples/basic_data_track/src/publisher.rs b/examples/basic_data_track/src/publisher.rs new file mode 100644 index 000000000..ba43c3d9f --- /dev/null +++ b/examples/basic_data_track/src/publisher.rs @@ -0,0 +1,39 @@ +use anyhow::Result; +use livekit::prelude::*; +use std::{env, time::Duration}; +use tokio::{signal, time}; + +#[tokio::main] +async fn main() -> Result<()> { + env_logger::init(); + + let url = env::var("LIVEKIT_URL").expect("LIVEKIT_URL is not set"); + let token = env::var("LIVEKIT_TOKEN").expect("LIVEKIT_TOKEN is not set"); + + let (room, _) = Room::connect(&url, &token, RoomOptions::default()).await?; + + let track = room.local_participant().publish_data_track("my_sensor_data").await?; + + tokio::select! { + _ = push_frames(track) => {} + _ = signal::ctrl_c() => {} + } + Ok(()) +} + +async fn read_sensor() -> Vec { + // Dynamically read some sensor data... + vec![0xFA; 256] +} + +async fn push_frames(track: LocalDataTrack) { + loop { + log::info!("Pushing frame"); + + let reading = read_sensor().await; + let frame = DataTrackFrame::new(reading).with_user_timestamp_now(); + + track.try_push(frame).inspect_err(|err| println!("Failed to push frame: {}", err)).ok(); + time::sleep(Duration::from_millis(500)).await + } +} diff --git a/examples/basic_data_track/src/subscriber.rs b/examples/basic_data_track/src/subscriber.rs new file mode 100644 index 000000000..e1107c350 --- /dev/null +++ b/examples/basic_data_track/src/subscriber.rs @@ -0,0 +1,51 @@ +use anyhow::Result; +use futures_util::StreamExt; +use livekit::prelude::*; +use std::env; +use tokio::{signal, sync::mpsc::UnboundedReceiver}; + +#[tokio::main] +async fn main() -> Result<()> { + env_logger::init(); + + let url = env::var("LIVEKIT_URL").expect("LIVEKIT_URL is not set"); + let token = env::var("LIVEKIT_TOKEN").expect("LIVEKIT_TOKEN is not set"); + + let (_, rx) = Room::connect(&url, &token, RoomOptions::default()).await?; + + tokio::select! { + _ = handle_publications(rx) => {} + _ = signal::ctrl_c() => {} + } + Ok(()) +} + +/// Subscribes to any published data tracks. +async fn handle_publications(mut rx: UnboundedReceiver) -> Result<()> { + while let Some(event) = rx.recv().await { + let RoomEvent::DataTrackPublished(track) = event else { + continue; + }; + subscribe(track).await? + } + Ok(()) +} + +/// Subscribes to the given data track and logs received frames. +async fn subscribe(track: RemoteDataTrack) -> Result<()> { + log::info!( + "Subscribing to '{}' published by '{}'", + track.info().name(), + track.publisher_identity() + ); + let mut subscription = track.subscribe().await?; + while let Some(frame) = subscription.next().await { + log::info!("Received frame ({} bytes)", frame.payload().len()); + + if let Some(duration) = frame.duration_since_timestamp() { + log::info!("Latency: {:?}", duration); + } + } + log::info!("Unsubscribed"); + Ok(()) +} diff --git a/examples/wgpu_room/src/app.rs b/examples/wgpu_room/src/app.rs index 5cb4f905b..cbcfdca65 100644 --- a/examples/wgpu_room/src/app.rs +++ b/examples/wgpu_room/src/app.rs @@ -1,9 +1,10 @@ use crate::{ + data_track::{LocalDataTrackTile, RemoteDataTrackTile, MAX_VALUE, TIME_WINDOW}, service::{AsyncCmd, LkService, UiCmd}, video_grid::VideoGrid, video_renderer::VideoRenderer, }; -use egui::{CornerRadius, Stroke}; +use egui::{emath, epaint, pos2, Color32, CornerRadius, Rect, Stroke}; use livekit::{e2ee::EncryptionType, prelude::*, track::VideoQuality, SimulateScenario}; use std::collections::HashMap; @@ -22,6 +23,8 @@ pub struct LkApp { async_runtime: tokio::runtime::Runtime, state: AppState, video_renderers: HashMap<(ParticipantIdentity, TrackSid), VideoRenderer>, + local_data_tracks: Vec, + remote_data_tracks: Vec, connecting: bool, connection_failure: Option, render_state: egui_wgpu::RenderState, @@ -55,6 +58,8 @@ impl LkApp { async_runtime, state, video_renderers: HashMap::new(), + local_data_tracks: Vec::new(), + remote_data_tracks: Vec::new(), connecting: false, connection_failure: None, render_state: cc.wgpu_render_state.clone().unwrap(), @@ -69,12 +74,17 @@ impl LkApp { self.connection_failure = Some(err.to_string()); } } + UiCmd::DataTrackPublished { track } => { + self.local_data_tracks.push(LocalDataTrackTile::new(track)); + } + UiCmd::DataTrackUnpublished => { + self.local_data_tracks.clear(); + } UiCmd::RoomEvent { event } => { log::info!("{:?}", event); match event { RoomEvent::TrackSubscribed { track, publication: _, participant } => { if let RemoteTrack::Video(ref video_track) = track { - // Create a new VideoRenderer let video_renderer = VideoRenderer::new( self.async_runtime.handle(), self.render_state.clone(), @@ -82,8 +92,6 @@ impl LkApp { ); self.video_renderers .insert((participant.identity(), track.sid()), video_renderer); - } else if let RemoteTrack::Audio(_) = track { - // TODO(theomonnom): Once we support media devices, we can play audio tracks here } } RoomEvent::TrackUnsubscribed { track, publication: _, participant } => { @@ -91,7 +99,6 @@ impl LkApp { } RoomEvent::LocalTrackPublished { track, publication: _, participant } => { if let LocalTrack::Video(ref video_track) = track { - // Also create a new VideoRenderer for local tracks let video_renderer = VideoRenderer::new( self.async_runtime.handle(), self.render_state.clone(), @@ -104,8 +111,14 @@ impl LkApp { RoomEvent::LocalTrackUnpublished { publication, participant } => { self.video_renderers.remove(&(participant.identity(), publication.sid())); } + RoomEvent::DataTrackPublished(track) => { + self.remote_data_tracks + .push(RemoteDataTrackTile::new(self.async_runtime.handle(), track)); + } RoomEvent::Disconnected { reason: _ } => { self.video_renderers.clear(); + self.local_data_tracks.clear(); + self.remote_data_tracks.clear(); } _ => {} } @@ -140,6 +153,9 @@ impl LkApp { if ui.button("SineWave").clicked() { let _ = self.service.send(AsyncCmd::ToggleSine); } + if ui.button("DataTrack").clicked() { + let _ = self.service.send(AsyncCmd::ToggleDataTrack); + } }); ui.menu_button("Debug", |ui| { @@ -319,47 +335,51 @@ impl LkApp { }); } - /// Draw a video grid of all participants + /// Draw a grid of all track tiles (video + data) fn central_panel(&mut self, ui: &mut egui::Ui) { let room = self.service.room(); - let show_videos = self.service.room().is_some(); + let connected = room.is_some(); - if show_videos && self.video_renderers.is_empty() { + let has_tiles = !self.video_renderers.is_empty() + || !self.local_data_tracks.is_empty() + || !self.remote_data_tracks.is_empty(); + + if connected && !has_tiles { ui.centered_and_justified(|ui| { - ui.label("No video tracks subscribed"); + ui.label("No tracks subscribed"); }); return; } egui::ScrollArea::vertical().show(ui, |ui| { VideoGrid::new("default_grid").max_columns(6).show(ui, |ui| { - if show_videos { - // Draw participant videos - for ((participant_sid, _), video_renderer) in &self.video_renderers { - ui.video_frame(|ui| { - let room = room.as_ref().unwrap().clone(); + if connected { + let room = room.as_ref().unwrap(); - if let Some(participant) = - room.remote_participants().get(participant_sid) - { - draw_video( - participant.name().as_str(), - participant.is_speaking(), - video_renderer, - ui, - ); + for ((participant_id, _), video_renderer) in &self.video_renderers { + ui.video_frame(|ui| { + if let Some(p) = room.remote_participants().get(participant_id) { + draw_video(p.name().as_str(), p.is_speaking(), video_renderer, ui); } else { + let lp = room.local_participant(); draw_video( - room.local_participant().name().as_str(), - room.local_participant().is_speaking(), + lp.name().as_str(), + lp.is_speaking(), video_renderer, ui, ); } }); } + + for tile in &mut self.local_data_tracks { + ui.video_frame(|ui| draw_local_data_track(tile, ui)); + } + + for tile in &self.remote_data_tracks { + ui.video_frame(|ui| draw_remote_data_track(tile, ui)); + } } else { - // Draw video skeletons when we're not connected for _ in 0..5 { ui.video_frame(|ui| { egui::Frame::none().fill(ui.style().visuals.code_bg_color).show( @@ -419,7 +439,6 @@ impl eframe::App for LkApp { } } -/// Draw a wgpu texture to the VideoGrid fn draw_video(name: &str, speaking: bool, video_renderer: &VideoRenderer, ui: &mut egui::Ui) { let rect = ui.available_rect_before_wrap(); let inner_rect = rect.shrink(1.0); @@ -434,7 +453,6 @@ fn draw_video(name: &str, speaking: bool, video_renderer: &VideoRenderer, ui: &m ); } - // Always draw a background in case we still didn't receive a frame let resolution = video_renderer.resolution(); if let Some(tex) = video_renderer.texture_id() { ui.painter().image( @@ -453,3 +471,167 @@ fn draw_video(name: &str, speaking: bool, video_renderer: &VideoRenderer, ui: &m egui::Color32::WHITE, ); } + +struct DataTrackChart<'a> { + points: &'a parking_lot::Mutex>, + name: &'a str, + publisher_label: &'a str, + drag_value: Option<&'a mut i32>, +} + +impl<'a> DataTrackChart<'a> { + fn new( + points: &'a parking_lot::Mutex>, + name: &'a str, + publisher_label: &'a str, + ) -> Self { + Self { points, name, publisher_label, drag_value: None } + } + + fn interactive(mut self, value: &'a mut i32) -> Self { + self.drag_value = Some(value); + self + } +} + +impl egui::Widget for DataTrackChart<'_> { + fn ui(self, ui: &mut egui::Ui) -> egui::Response { + let mut drag_value = self.drag_value; + let interactive = drag_value.is_some(); + let sense = if interactive { egui::Sense::click_and_drag() } else { egui::Sense::hover() }; + + let desired_size = ui.available_size(); + let (rect, mut response) = ui.allocate_exact_size(desired_size, sense); + let painter = ui.painter(); + + let bg = Color32::from_rgb(0x1a, 0x1a, 0x2e); + painter.rect_filled(rect, CornerRadius::default(), bg); + + let v_margin = rect.height() * 0.15; + let h_margin = 8.0; + let label_width = 32.0; + let plot_rect = Rect::from_min_max( + pos2(rect.min.x + h_margin, rect.min.y + v_margin), + pos2(rect.max.x - h_margin - label_width, rect.max.y - v_margin), + ); + + let time_window_secs = TIME_WINDOW.as_secs_f32(); + let to_screen = emath::RectTransform::from_to( + Rect::from_x_y_ranges(time_window_secs..=0.0, MAX_VALUE..=0.0), + plot_rect, + ); + + let guide_color = Color32::from_rgb(0x40, 0x40, 0x50); + let max_y = (to_screen * pos2(0.0, MAX_VALUE)).y; + let min_y = (to_screen * pos2(0.0, 0.0)).y; + painter.line_segment( + [pos2(plot_rect.min.x, max_y), pos2(plot_rect.max.x, max_y)], + Stroke::new(1.0, guide_color), + ); + painter.line_segment( + [pos2(plot_rect.min.x, min_y), pos2(plot_rect.max.x, min_y)], + Stroke::new(1.0, guide_color), + ); + + if let Some(value) = &mut drag_value { + if let Some(pointer_pos) = response.interact_pointer_pos() { + let from_screen = to_screen.inverse(); + let logical = from_screen * pointer_pos; + let new_val = (logical.y as i32).clamp(0, MAX_VALUE as i32); + if **value != new_val { + **value = new_val; + response.mark_changed(); + } + } + } + + let now = std::time::Instant::now(); + let mut points = self.points.lock(); + while points.back().is_some_and(|(t, _)| now.duration_since(*t) > TIME_WINDOW) { + points.pop_back(); + } + + let is_interacting = response.interact_pointer_pos().is_some(); + let display_val = drag_value + .as_deref() + .copied() + .filter(|_| !points.is_empty() || is_interacting) + .or_else(|| points.front().map(|(_, v)| *v)); + + let line_color = Color32::from_rgb(0xFF, 0x44, 0x44); + + if !points.is_empty() { + let mut screen_points = Vec::with_capacity(points.len() + 1); + if let Some(val) = display_val { + screen_points.push(to_screen * pos2(0.0, val as f32)); + } + for &(t, val) in points.iter() { + let age = now.duration_since(t).as_secs_f32(); + screen_points.push(to_screen * pos2(age, val as f32)); + } + drop(points); + painter + .add(epaint::Shape::line(screen_points, epaint::PathStroke::new(2.0, line_color))); + ui.ctx().request_repaint(); + } else { + drop(points); + } + + if let Some(val) = display_val { + let newest_screen = to_screen * pos2(0.0, val as f32); + let is_active = interactive && (response.hovered() || response.dragged()); + let dot_radius = if is_active { 6.0 } else { 4.0 }; + painter.circle_filled(newest_screen, dot_radius, line_color); + if is_active { + painter.circle_stroke( + newest_screen, + dot_radius + 2.0, + Stroke::new(1.5, Color32::WHITE), + ); + } + + painter.text( + pos2(plot_rect.max.x + 8.0, newest_screen.y), + egui::Align2::LEFT_CENTER, + val.to_string(), + egui::FontId::monospace(14.0), + Color32::WHITE, + ); + } else { + let hint = if interactive { "Drag to Push Frames…" } else { "Waiting for Frames…" }; + painter.text( + rect.center(), + egui::Align2::CENTER_CENTER, + hint, + egui::FontId::proportional(18.0), + Color32::WHITE, + ); + } + + painter.text( + pos2(rect.min.x + 5.0, rect.max.y - 5.0), + egui::Align2::LEFT_BOTTOM, + format!("Data: {} ({})", self.name, self.publisher_label), + egui::FontId::default(), + Color32::WHITE, + ); + + if interactive { + response = response.on_hover_cursor(egui::CursorIcon::ResizeVertical); + } + + response + } +} + +fn draw_local_data_track(tile: &mut LocalDataTrackTile, ui: &mut egui::Ui) { + let chart = + DataTrackChart::new(&tile.points, &tile.name, "local").interactive(&mut tile.slider_value); + if ui.add(chart).changed() { + tile.push_value(); + } +} + +fn draw_remote_data_track(tile: &RemoteDataTrackTile, ui: &mut egui::Ui) { + ui.add(DataTrackChart::new(&tile.points, &tile.name, &tile.publisher_identity)); +} diff --git a/examples/wgpu_room/src/data_track.rs b/examples/wgpu_room/src/data_track.rs new file mode 100644 index 000000000..8f9ce0852 --- /dev/null +++ b/examples/wgpu_room/src/data_track.rs @@ -0,0 +1,62 @@ +use futures::StreamExt; +use livekit::prelude::*; +use parking_lot::Mutex; +use std::collections::VecDeque; +use std::sync::Arc; +use std::time::{Duration, Instant}; + +pub const TIME_WINDOW: Duration = Duration::from_secs(30); +pub const MAX_VALUE: f32 = 512.0; + +pub struct LocalDataTrackTile { + track: LocalDataTrack, + pub slider_value: i32, + pub points: Arc>>, + pub name: String, +} + +impl LocalDataTrackTile { + pub fn new(track: LocalDataTrack) -> Self { + let name = track.info().name().to_string(); + Self { track, slider_value: 0, points: Arc::new(Mutex::new(VecDeque::new())), name } + } + + pub fn push_value(&self) { + let frame = DataTrackFrame::new(self.slider_value.to_string().into_bytes()); + let _ = self.track.try_push(frame); + self.points.lock().push_front((Instant::now(), self.slider_value)); + } +} + +pub struct RemoteDataTrackTile { + pub points: Arc>>, + pub publisher_identity: String, + pub name: String, +} + +impl RemoteDataTrackTile { + pub fn new(async_handle: &tokio::runtime::Handle, track: RemoteDataTrack) -> Self { + let points = Arc::new(Mutex::new(VecDeque::new())); + let points_ref = points.clone(); + let publisher_identity = track.publisher_identity().to_string(); + let name = track.info().name().to_string(); + + async_handle.spawn(async move { + let mut stream = match track.subscribe().await { + Ok(s) => s, + Err(err) => { + log::error!("Failed to subscribe to data track: {err}"); + return; + } + }; + while let Some(frame) = stream.next().await { + let payload = frame.payload(); + let Ok(s) = std::str::from_utf8(&payload) else { continue }; + let Ok(value) = s.parse::() else { continue }; + points_ref.lock().push_front((Instant::now(), value)); + } + }); + + Self { points, publisher_identity, name } + } +} diff --git a/examples/wgpu_room/src/main.rs b/examples/wgpu_room/src/main.rs index b9c752210..ba621bdfd 100644 --- a/examples/wgpu_room/src/main.rs +++ b/examples/wgpu_room/src/main.rs @@ -5,6 +5,7 @@ use std::thread; use std::time::Duration; mod app; +mod data_track; mod logo_track; mod service; mod sine_track; diff --git a/examples/wgpu_room/src/service.rs b/examples/wgpu_room/src/service.rs index 66976eef2..04ee939f3 100644 --- a/examples/wgpu_room/src/service.rs +++ b/examples/wgpu_room/src/service.rs @@ -19,6 +19,7 @@ pub enum AsyncCmd { SimulateScenario { scenario: SimulateScenario }, ToggleLogo, ToggleSine, + ToggleDataTrack, SubscribeTrack { publication: RemoteTrackPublication }, UnsubscribeTrack { publication: RemoteTrackPublication }, SetVideoQuality { publication: RemoteTrackPublication, quality: VideoQuality }, @@ -30,6 +31,8 @@ pub enum AsyncCmd { pub enum UiCmd { ConnectResult { result: RoomResult<()> }, RoomEvent { event: RoomEvent }, + DataTrackPublished { track: LocalDataTrack }, + DataTrackUnpublished, } /// AppService is the "asynchronous" part of our application, where we connect to a room and @@ -82,6 +85,7 @@ async fn service_task(inner: Arc, mut cmd_rx: mpsc::UnboundedRecei room: Arc, logo_track: LogoTrack, sine_track: SineTrack, + data_track: Option, } let mut running_state = None; @@ -111,6 +115,7 @@ async fn service_task(inner: Arc, mut cmd_rx: mpsc::UnboundedRecei room: new_room.clone(), logo_track: LogoTrack::new(new_room.clone()), sine_track: SineTrack::new(new_room.clone(), SineParameters::default()), + data_track: None, }); // Allow direct access to the room from the UI (Used for sync access) @@ -155,6 +160,24 @@ async fn service_task(inner: Arc, mut cmd_rx: mpsc::UnboundedRecei } } } + AsyncCmd::ToggleDataTrack => { + if let Some(state) = running_state.as_mut() { + if let Some(track) = state.data_track.take() { + track.unpublish(); + let _ = inner.ui_tx.send(UiCmd::DataTrackUnpublished); + } else { + match state.room.local_participant().publish_data_track("slider").await { + Ok(track) => { + let _ = inner + .ui_tx + .send(UiCmd::DataTrackPublished { track: track.clone() }); + state.data_track = Some(track); + } + Err(err) => log::error!("failed to publish data track: {err}"), + } + } + } + } AsyncCmd::SubscribeTrack { publication } => { publication.set_subscribed(true); } diff --git a/livekit-datatrack/src/remote/mod.rs b/livekit-datatrack/src/remote/mod.rs index d3ff69738..edd8845c0 100644 --- a/livekit-datatrack/src/remote/mod.rs +++ b/livekit-datatrack/src/remote/mod.rs @@ -80,7 +80,7 @@ impl DataTrack { /// Note that newly created subscriptions only receive frames published after /// the initial subscription is established. /// - pub async fn subscribe(&self) -> Result { + pub async fn subscribe(&self) -> Result { self.subscribe_with_options(DataTrackSubscribeOptions::default()).await } @@ -92,7 +92,7 @@ impl DataTrack { pub async fn subscribe_with_options( &self, options: DataTrackSubscribeOptions, - ) -> Result { + ) -> Result { let (result_tx, result_rx) = oneshot::channel(); let subscribe_event = SubscribeRequest { sid: self.info.sid(), options, result_tx }; self.inner() @@ -109,7 +109,7 @@ impl DataTrack { .map_err(|_| DataTrackSubscribeError::Timeout)? .map_err(|_| DataTrackSubscribeError::Disconnected)??; - Ok(DataTrackSubscription { inner: BroadcastStream::new(frame_rx) }) + Ok(DataTrackStream { inner: BroadcastStream::new(frame_rx) }) } /// Identity of the participant who published the track. @@ -119,11 +119,11 @@ impl DataTrack { } /// A stream of [`DataTrackFrame`]s received from a [`RemoteDataTrack`]. -pub struct DataTrackSubscription { +pub struct DataTrackStream { inner: BroadcastStream, } -impl Stream for DataTrackSubscription { +impl Stream for DataTrackStream { type Item = DataTrackFrame; fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { diff --git a/livekit/Cargo.toml b/livekit/Cargo.toml index c6dae9dec..8a6eef172 100644 --- a/livekit/Cargo.toml +++ b/livekit/Cargo.toml @@ -4,7 +4,7 @@ version = "0.7.35" edition.workspace = true license.workspace = true description = "Rust Client SDK for LiveKit" -repository.workspace = true +repository = "https://github.com/livekit/rust-sdks" [features] # By default ws TLS is not enabled @@ -30,22 +30,24 @@ livekit-runtime = { workspace = true } livekit-api = { workspace = true } libwebrtc = { workspace = true } livekit-protocol = { workspace = true } +livekit-datatrack = { workspace = true } prost = "0.12" -serde = { workspace = true, features = ["derive"] } -serde_json = { workspace = true } -tokio = { workspace = true, default-features = false, features = ["sync", "macros", "fs"] } -parking_lot = { workspace = true } -futures-util = { workspace = true, default-features = false, features = ["sink"] } -thiserror = { workspace = true } -lazy_static = { workspace = true } -log = { workspace = true } +serde = { version = "1", features = ["derive"] } +serde_json = "1.0" +tokio = { version = "1", default-features = false, features = ["sync", "macros", "fs"] } +parking_lot = { version = "0.12" } +futures-util = { version = "0.3", default-features = false, features = ["sink"] } +thiserror = "1.0" +lazy_static = "1.4" +log = "0.4" chrono = "0.4.38" semver = "1.0" libloading = { version = "0.8.6" } -bytes = { workspace = true } +bytes = "1.10.1" bmrng = "0.5.2" base64 = "0.22" [dev-dependencies] -anyhow = { workspace = true } +anyhow = "1.0.99" test-log = "0.2.18" +test-case = "3.3" \ No newline at end of file diff --git a/livekit/src/prelude.rs b/livekit/src/prelude.rs index 3e6bda23b..95aff9cfe 100644 --- a/livekit/src/prelude.rs +++ b/livekit/src/prelude.rs @@ -15,6 +15,11 @@ pub use livekit_protocol::AudioTrackFeature; pub use crate::{ + data_track::{ + DataTrackFrame, DataTrackInfo, DataTrackOptions, DataTrackSid, DataTrackSubscribeError, + DataTrackSubscribeOptions, DataTrackStream, LocalDataTrack, PublishError, + PushFrameError, PushFrameErrorReason, RemoteDataTrack, + }, id::*, participant::{ ConnectionQuality, DisconnectReason, LocalParticipant, Participant, PerformRpcData, diff --git a/livekit/src/room/data_track.rs b/livekit/src/room/data_track.rs new file mode 100644 index 000000000..52e83917b --- /dev/null +++ b/livekit/src/room/data_track.rs @@ -0,0 +1,16 @@ +// Copyright 2025 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Re-export everything in the "api" module publicly. +pub use livekit_datatrack::api::*; diff --git a/livekit/src/room/e2ee/data_track.rs b/livekit/src/room/e2ee/data_track.rs new file mode 100644 index 000000000..068819ecc --- /dev/null +++ b/livekit/src/room/e2ee/data_track.rs @@ -0,0 +1,85 @@ +// Copyright 2025 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use crate::{id::ParticipantIdentity, E2eeManager}; +use bytes::Bytes; +use livekit_datatrack::backend as dt; + +/// Wrapper around [`E2eeManager`] implementing [`dt::EncryptionProvider`]. +#[derive(Debug)] +pub(crate) struct DataTrackEncryptionProvider { + manager: E2eeManager, + sender_identity: ParticipantIdentity, +} + +impl DataTrackEncryptionProvider { + pub fn new(manager: E2eeManager, sender_identity: ParticipantIdentity) -> Self { + Self { manager, sender_identity } + } +} + +impl dt::EncryptionProvider for DataTrackEncryptionProvider { + fn encrypt(&self, payload: bytes::Bytes) -> Result { + let key_index = + self.manager.key_provider().map_or(0, |kp| kp.get_latest_key_index() as u32); + + let encrypted = self + .manager + .encrypt_data(payload.into(), &self.sender_identity, key_index) + .map_err(|_| dt::EncryptionError)?; + + debug_assert_eq!( + encrypted.key_index as u32, + key_index, + "E2EE key index changed during encryption (possible race or inconsistent key selection)" + ); + + let payload = encrypted.data.into(); + let iv = encrypted.iv.try_into().map_err(|_| dt::EncryptionError)?; + let key_index = encrypted.key_index.try_into().map_err(|_| dt::EncryptionError)?; + + Ok(dt::EncryptedPayload { payload, iv, key_index }) + } +} + +/// Wrapper around [`E2eeManager`] implementing [`dt::DecryptionProvider`]. +#[derive(Debug)] +pub(crate) struct DataTrackDecryptionProvider { + manager: E2eeManager, +} + +impl DataTrackDecryptionProvider { + pub fn new(manager: E2eeManager) -> Self { + Self { manager } + } +} + +impl dt::DecryptionProvider for DataTrackDecryptionProvider { + fn decrypt( + &self, + payload: dt::EncryptedPayload, + sender_identity: &str, + ) -> Result { + let decrypted = self + .manager + .decrypt_data( + payload.payload.into(), + payload.iv.to_vec(), + payload.key_index as u32, + sender_identity, + ) + .ok_or_else(|| dt::DecryptionError)?; + Ok(Bytes::from(decrypted)) + } +} diff --git a/livekit/src/room/e2ee/manager.rs b/livekit/src/room/e2ee/manager.rs index 1e583b9c4..e58c55c32 100644 --- a/livekit/src/room/e2ee/manager.rs +++ b/livekit/src/room/e2ee/manager.rs @@ -31,6 +31,7 @@ use crate::{ prelude::{LocalTrack, LocalTrackPublication, RemoteTrack, RemoteTrackPublication}, rtc_engine::lk_runtime::LkRuntime, }; +use std::fmt::Debug; type StateChangedHandler = Box; @@ -246,19 +247,18 @@ impl E2eeManager { } /// Decrypt data received from a data channel - pub fn handle_encrypted_data( + pub(crate) fn decrypt_data( &self, - data: &[u8], - iv: &[u8], - participant_identity: &str, + data: Vec, + iv: Vec, key_index: u32, + participant_identity: &str, ) -> Option> { let inner = self.inner.lock(); let data_packet_cryptor = inner.data_packet_cryptor.as_ref()?; - let encrypted_packet = EncryptedPacket { data: data.to_vec(), iv: iv.to_vec(), key_index }; - + let encrypted_packet = EncryptedPacket { data, iv, key_index }; match data_packet_cryptor.decrypt(participant_identity, &encrypted_packet) { Ok(decrypted_data) => Some(decrypted_data), Err(e) => { @@ -269,10 +269,10 @@ impl E2eeManager { } /// Encrypt data for transmission over a data channel - pub fn encrypt_data( + pub(crate) fn encrypt_data( &self, - data: &[u8], - participant_identity: &str, + data: Vec, + participant_identity: &ParticipantIdentity, key_index: u32, ) -> Result> { let inner = self.inner.lock(); @@ -280,6 +280,14 @@ impl E2eeManager { let data_packet_cryptor = inner.data_packet_cryptor.as_ref().ok_or("DataPacketCryptor is not initialized")?; - data_packet_cryptor.encrypt(participant_identity, key_index, data) + data_packet_cryptor.encrypt(participant_identity.as_str(), key_index, &data) + } +} + +impl Debug for E2eeManager { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("E2eeManager") + .field("enabled", &self.inner.lock().enabled) + .finish_non_exhaustive() } } diff --git a/livekit/src/room/e2ee/mod.rs b/livekit/src/room/e2ee/mod.rs index 8864dc421..e1235d81d 100644 --- a/livekit/src/room/e2ee/mod.rs +++ b/livekit/src/room/e2ee/mod.rs @@ -19,6 +19,9 @@ use self::key_provider::KeyProvider; pub mod key_provider; pub mod manager; +/// Provider implementations for data track. +pub(crate) mod data_track; + #[derive(Default, Debug, Clone, Copy, PartialEq, Eq)] pub enum EncryptionType { #[default] diff --git a/livekit/src/room/mod.rs b/livekit/src/room/mod.rs index 11d733994..0035237b6 100644 --- a/livekit/src/room/mod.rs +++ b/livekit/src/room/mod.rs @@ -13,6 +13,7 @@ // limitations under the License. use bmrng::unbounded::UnboundedRequestReceiver; +use futures_util::{Stream, StreamExt}; use libwebrtc::{ native::frame_cryptor::EncryptionState, prelude::{ @@ -23,6 +24,10 @@ use libwebrtc::{ RtcError, }; use livekit_api::signal_client::{SignalOptions, SignalSdkOptions, SIGNAL_CONNECT_TIMEOUT}; +use livekit_datatrack::{ + api::{DataTrackSid, RemoteDataTrack}, + backend as dt, +}; use livekit_protocol::observer::Dispatcher; use livekit_protocol::{self as proto, encryption}; use livekit_runtime::JoinHandle; @@ -45,6 +50,7 @@ pub use self::{ }; pub use crate::rtc_engine::SimulateScenario; use crate::{ + e2ee::data_track::{DataTrackDecryptionProvider, DataTrackEncryptionProvider}, participant::ConnectionQuality, prelude::*, registered_audio_filter_plugins, @@ -55,6 +61,7 @@ use crate::{ }; pub mod data_stream; +pub mod data_track; pub mod e2ee; pub mod id; pub mod options; @@ -250,6 +257,10 @@ pub enum RoomEvent { TokenRefreshed { token: String, }, + /// A remote participant published a data track. + DataTrackPublished(RemoteDataTrack), + /// A remote participant has unpublished a data track. + DataTrackUnpublished(DataTrackSid), } #[derive(Debug, Clone, Copy, Eq, PartialEq)] @@ -463,6 +474,8 @@ pub(crate) struct RoomSession { e2ee_manager: E2eeManager, incoming_stream_manager: IncomingStreamManager, outgoing_stream_manager: OutgoingStreamManager, + local_dt_input: dt::local::ManagerInput, + remote_dt_input: dt::remote::ManagerInput, handle: AsyncMutex>, } @@ -470,6 +483,10 @@ struct Handle { room_handle: JoinHandle<()>, incoming_stream_handle: JoinHandle<()>, outgoing_stream_handle: JoinHandle<()>, + local_dt_task: JoinHandle<()>, + local_dt_forward_task: JoinHandle<()>, + remote_dt_task: JoinHandle<()>, + remote_dt_forward_task: JoinHandle<()>, close_tx: broadcast::Sender<()>, } @@ -491,9 +508,11 @@ impl Room { mut options: RoomOptions, ) -> RoomResult<(Self, mpsc::UnboundedReceiver)> { // TODO(theomonnom): move connection logic to the RoomSession + let with_dc_encryption = options.encryption.is_some(); let encryption_options = options.encryption.take().or(options.e2ee.take()); let e2ee_manager = E2eeManager::new(encryption_options, with_dc_encryption); + let mut signal_options = SignalOptions::default(); signal_options.sdk_options = options.sdk_options.clone().into(); signal_options.auto_subscribe = options.auto_subscribe; @@ -618,6 +637,25 @@ impl Room { } }); + let encryption_provider = e2ee_manager.enabled().then(|| { + Arc::new(DataTrackEncryptionProvider::new( + e2ee_manager.clone(), + local_participant.identity().clone(), + )) as Arc + }); + let decryption_provider = e2ee_manager.enabled().then(|| { + Arc::new(DataTrackDecryptionProvider::new(e2ee_manager.clone())) + as Arc + }); + + let local_dt_options = dt::local::ManagerOptions { encryption_provider }; + let (local_dt_manager, local_dt_input, local_dt_output) = + dt::local::Manager::new(local_dt_options); + + let remote_dt_options = dt::remote::ManagerOptions { decryption_provider }; + let (remote_dt_manager, remote_dt_input, remote_dt_output) = + dt::remote::Manager::new(remote_dt_options); + let (incoming_stream_manager, open_rx) = IncomingStreamManager::new(); let (outgoing_stream_manager, packet_rx) = OutgoingStreamManager::new(); @@ -648,6 +686,8 @@ impl Room { e2ee_manager: e2ee_manager.clone(), incoming_stream_manager, outgoing_stream_manager, + local_dt_input, + remote_dt_input, handle: Default::default(), }); inner.local_participant.set_session(Arc::downgrade(&inner)); @@ -724,10 +764,28 @@ impl Room { close_rx.resubscribe(), )); + let local_dt_task = livekit_runtime::spawn(local_dt_manager.run()); + let local_dt_forward_task = livekit_runtime::spawn( + inner.clone().local_dt_forward_task(local_dt_output, close_rx.resubscribe()), + ); + + let remote_dt_task = livekit_runtime::spawn(remote_dt_manager.run()); + let remote_dt_forward_task = livekit_runtime::spawn( + inner.clone().remote_dt_forward_task(remote_dt_output, close_rx.resubscribe()), + ); + let room_handle = livekit_runtime::spawn(inner.clone().room_task(engine_events, close_rx)); - let handle = - Handle { room_handle, incoming_stream_handle, outgoing_stream_handle, close_tx }; + let handle = Handle { + room_handle, + incoming_stream_handle, + outgoing_stream_handle, + local_dt_task, + local_dt_forward_task, + remote_dt_task, + remote_dt_forward_task, + close_tx, + }; inner.handle.lock().await.replace(handle); Ok((Self { inner }, events)) @@ -972,7 +1030,12 @@ impl RoomSession { EngineEvent::TrackMuted { sid, muted } => { self.handle_server_initiated_mute_track(sid, muted); } - _ => {} + EngineEvent::LocalDataTrackInput(event) => { + _ = self.local_dt_input.send(event); + } + EngineEvent::RemoteDataTrackInput(event) => { + _ = self.remote_dt_input.send(event); + } } Ok(()) @@ -992,6 +1055,10 @@ impl RoomSession { let _ = handle.close_tx.send(()); let _ = handle.incoming_stream_handle.await; let _ = handle.outgoing_stream_handle.await; + let _ = handle.local_dt_forward_task.await; + let _ = handle.local_dt_task.await; + let _ = handle.remote_dt_forward_task.await; + let _ = handle.remote_dt_task.await; let _ = handle.room_handle.await; self.dispatcher.clear(); @@ -1322,6 +1389,9 @@ impl RoomSession { }); } + let publish_data_tracks = + dt::local::publish_responses_for_sync_state(self.local_dt_input.query_tracks().await); + let sync_state = proto::SyncState { answer: answer.map(|a| proto::SessionDescription { sdp: a.to_string(), @@ -1344,7 +1414,7 @@ impl RoomSession { publish_tracks: self.local_participant.published_tracks_info(), data_channels: dcs, datachannel_receive_states: session.data_channel_receive_states(), - publish_data_tracks: Default::default(), + publish_data_tracks, }; log::debug!("sending sync state {:?}", sync_state); @@ -1473,6 +1543,12 @@ impl RoomSession { fn handle_restarted(self: &Arc, tx: oneshot::Sender<()>) { let _ = tx.send(()); + // Ensure the SFU knows about existing data track publications. + _ = self.local_dt_input.send(dt::local::InputEvent::RepublishTracks); + + // Ensure SFU continues delivering packets for existing data track subscriptions. + _ = self.remote_dt_input.send(dt::remote::InputEvent::ResendSubscriptionUpdates); + // Unpublish and republish every track // At this time we know that the RtcSession is successfully restarted let published_tracks = self.local_participant.track_publications(); @@ -1926,6 +2002,54 @@ impl RoomSession { let event = RoomEvent::TokenRefreshed { token }; self.dispatcher.dispatch(&event); } + + /// Task for handling output events from the local data track manager. + async fn local_dt_forward_task( + self: Arc, + mut events: impl Stream + Unpin, + mut close_rx: broadcast::Receiver<()>, + ) { + loop { + tokio::select! { + event = events.next() => match event { + Some(event) => _ = self.rtc_engine.handle_local_data_track_output(event).await, + None => break, + }, + _ = close_rx.recv() => { + _ = self.local_dt_input.send(dt::local::InputEvent::Shutdown); + break; + }, + } + } + } + + /// Task for handling output events from the remote data track manager. + async fn remote_dt_forward_task( + self: Arc, + mut events: impl Stream + Unpin, + mut close_rx: broadcast::Receiver<()>, + ) { + loop { + tokio::select! { + event = events.next() => match event { + Some(event) => match event { + dt::remote::OutputEvent::TrackPublished(event) => { + _ = self.dispatcher.dispatch(&RoomEvent::DataTrackPublished(event.track)); + } + dt::remote::OutputEvent::TrackUnpublished(event) => { + _ = self.dispatcher.dispatch(&RoomEvent::DataTrackUnpublished(event.sid)); + } + other => _ = self.rtc_engine.handle_remote_data_track_output(other).await + }, + None => break, + }, + _ = close_rx.recv() => { + _ = self.remote_dt_input.send(dt::remote::InputEvent::Shutdown); + break; + }, + } + } + } } /// Receives stream readers for newly-opened streams and dispatches room events. diff --git a/livekit/src/room/participant/local_participant.rs b/livekit/src/room/participant/local_participant.rs index d5c1ec581..d37015dac 100644 --- a/livekit/src/room/participant/local_participant.rs +++ b/livekit/src/room/participant/local_participant.rs @@ -31,6 +31,7 @@ use crate::{ ByteStreamInfo, ByteStreamWriter, StreamByteOptions, StreamResult, StreamTextOptions, TextStreamInfo, TextStreamWriter, }, + data_track::{self, DataTrack, DataTrackOptions, Local}, e2ee::EncryptionType, options::{self, compute_video_encodings, video_layers_from_encodings, TrackPublishOptions}, prelude::*, @@ -254,6 +255,43 @@ impl LocalParticipant { vec } + /// Publishes a data track. + /// + /// # Returns + /// + /// The published data track if successful. Use [`LocalDataTrack::try_push`] + /// to send data frames on the track. + /// + /// # Examples + /// + /// Publish a track named "my_track": + /// + /// ``` + /// # use livekit::prelude::*; + /// # async fn with_room(room: Room) -> Result<(), PublishError> { + /// let track = room + /// .local_participant() + /// .publish_data_track("my_track") + /// .await?; + /// # Ok(()) + /// # } + /// ``` + /// + /// Note: if you are self-hosting the LiveKit SFU and get [`data_track::PublishError::Timeout`], + /// this may indicate you are using an outdated release that does not support data tracks. + /// + pub async fn publish_data_track( + &self, + options: impl Into, + ) -> Result, data_track::PublishError> { + self.session() + .ok_or(PublishError::Disconnected)? + .local_dt_input + .publish_track(options.into()) + .await + } + + /// Publishes a media track. pub async fn publish_track( &self, track: LocalTrack, @@ -520,6 +558,7 @@ impl LocalParticipant { self.inner.rtc_engine.publish_data(packet, kind, true).await.map_err(Into::into) } + /// Publishes a data packet. pub async fn publish_data(&self, packet: DataPacket) -> RoomResult<()> { let kind = match packet.reliable { true => DataPacketKind::Reliable, diff --git a/livekit/src/rtc_engine/dc_sender.rs b/livekit/src/rtc_engine/dc_sender.rs new file mode 100644 index 000000000..c2faedf78 --- /dev/null +++ b/livekit/src/rtc_engine/dc_sender.rs @@ -0,0 +1,172 @@ +// Copyright 2025 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use bytes::Bytes; +use libwebrtc::{self as rtc, data_channel::DataChannel}; +use std::collections::VecDeque; +use tokio::sync::{mpsc, watch}; + +/// Options for constructing a [`DataChannelSender`]. +pub struct DataChannelSenderOptions { + pub low_buffer_threshold: u64, + pub dc: DataChannel, + pub close_rx: watch::Receiver, +} + +/// A task responsible for sending payloads over a single RTC data channel. +/// +/// This implements the same backpressure logic used by `SessionInner::data_channel_task`, +/// but is decoupled from message encoding and retry logic specific to data packets. +/// +/// It was originally introduced to support sending data track packets; however, the +/// implementation is generic and works with arbitrary payloads. +/// +/// In a future refactor, it would be worth revisiting how the logic in +/// `SessionInner::data_channel_task` can be decoupled from session likely by reusing this +/// sender and moving encoding and retry concerns into a separate layer +/// (see the `livekit-datatrack` crate for an example of this approach). +/// +pub struct DataChannelSender { + /// Channel for receiving payloads to be enqueued for sending. + send_rx: mpsc::Receiver, + + /// Channel for receiving events from the data channel. + dc_event_rx: mpsc::UnboundedReceiver, + + dc_event_tx: mpsc::UnboundedSender, + + /// Receiver used to end the task when changed. + close_rx: watch::Receiver, + + /// Reference to the data channel to use for sending. + dc: DataChannel, + + low_buffer_threshold: u64, + + /// Number of bytes in the data channel's internal buffer. + buffered_amount: u64, + + /// Payloads enqueued for sending. + send_queue: VecDeque, +} + +impl DataChannelSender { + /// Buffer size of the channel used to sending events to the task. + const CHANNEL_BUFFER_SIZE: usize = 128; + + /// Creates a new sender. + /// + /// Returns a tuple containing the following: + /// - The sender itself to be spawned by the caller (see [`DataChannelSender::run`]). + /// - Channel for sending payloads over the data channel. + /// + pub fn new(options: DataChannelSenderOptions) -> (Self, mpsc::Sender) { + let (send_tx, send_rx) = mpsc::channel(Self::CHANNEL_BUFFER_SIZE); + let (dc_event_tx, dc_event_rx) = mpsc::unbounded_channel(); + + let sender = Self { + low_buffer_threshold: options.low_buffer_threshold, + dc: options.dc, + send_rx, + dc_event_rx, + dc_event_tx, + close_rx: options.close_rx, + buffered_amount: 0, + send_queue: VecDeque::default(), + }; + (sender, send_tx) + } + + /// Run the sender task, consuming self. + /// + /// The sender will continue running until `close_rx` changes. + /// + pub async fn run(mut self) { + log::debug!("Send task started for data channel '{}'", self.dc.label()); + self.register_dc_callbacks(); + loop { + tokio::select! { + Some(event) = self.dc_event_rx.recv() => { + let DataChannelEvent::BytesSent(bytes_sent) = event; + self.handle_bytes_sent(bytes_sent); + } + Some(payload) = self.send_rx.recv() => { + self.handle_enqueue_for_send(payload) + } + _ = self.close_rx.changed() => break + } + } + + if !self.send_queue.is_empty() { + let unsent_bytes: usize = + self.send_queue.into_iter().map(|payload| payload.len()).sum(); + log::info!("{} byte(s) remain in queue", unsent_bytes); + } + log::debug!("Send task ended for data channel '{}'", self.dc.label()); + } + + fn send_until_threshold(&mut self) { + while self.buffered_amount <= self.low_buffer_threshold { + let Some(payload) = self.send_queue.pop_front() else { + break; + }; + self.buffered_amount += payload.len() as u64; + _ = self + .dc + .send(&payload, true) + .inspect_err(|err| log::error!("Failed to send data: {}", err)); + } + } + + fn handle_enqueue_for_send(&mut self, payload: Bytes) { + self.send_queue.push_back(payload); + self.send_until_threshold(); + } + + fn handle_bytes_sent(&mut self, bytes_sent: u64) { + if self.buffered_amount < bytes_sent { + log::error!("Unexpected buffer amount"); + self.buffered_amount = 0; + return; + } + self.buffered_amount -= bytes_sent; + self.send_until_threshold(); + } + + fn register_dc_callbacks(&self) { + self.dc.on_buffered_amount_change( + on_buffered_amount_change(self.dc_event_tx.downgrade()).into(), + ); + } +} + +/// Event produced by the RTC data channel. +#[derive(Debug)] +enum DataChannelEvent { + /// Indicates the specified number of bytes have been sent. + BytesSent(u64), +} +// Note: if we also need to know when the data channel's state changes, +// we can add an event for that here and register another callback following this same pattern. + +fn on_buffered_amount_change( + event_tx: mpsc::WeakUnboundedSender, +) -> rtc::data_channel::OnBufferedAmountChange { + Box::new(move |bytes_sent| { + // Note: this callback is holding onto a weak sender, which will not + // prevent the channel from closing. + let Some(event_tx) = event_tx.upgrade() else { return }; + _ = event_tx.send(DataChannelEvent::BytesSent(bytes_sent)); + }) +} diff --git a/livekit/src/rtc_engine/mod.rs b/livekit/src/rtc_engine/mod.rs index 85dec231a..1ad21f7de 100644 --- a/livekit/src/rtc_engine/mod.rs +++ b/livekit/src/rtc_engine/mod.rs @@ -12,13 +12,13 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::{borrow::Cow, fmt::Debug, sync::Arc, time::Duration}; - use libwebrtc::prelude::*; use livekit_api::signal_client::{SignalError, SignalOptions}; +use livekit_datatrack::backend as dt; use livekit_protocol as proto; use livekit_runtime::{interval, Interval, JoinHandle}; use parking_lot::{RwLock, RwLockReadGuard}; +use std::{borrow::Cow, fmt::Debug, sync::Arc, time::Duration}; use thiserror::Error; use tokio::sync::{ mpsc, oneshot, Mutex as AsyncMutex, Notify, RwLock as AsyncRwLock, @@ -40,6 +40,7 @@ use crate::{ }; use crate::{ChatMessage, E2eeManager, TranscriptionSegment}; +mod dc_sender; pub mod lk_runtime; mod peer_transport; mod rtc_events; @@ -191,6 +192,8 @@ pub enum EngineEvent { sid: String, muted: bool, }, + LocalDataTrackInput(dt::local::InputEvent), + RemoteDataTrackInput(dt::remote::InputEvent), } /// Represents a running RtcSession with the ability to close the session @@ -273,6 +276,30 @@ impl RtcEngine { session.simulate_scenario(scenario).await } + pub async fn handle_local_data_track_output( + &self, + event: dt::local::OutputEvent, + ) -> EngineResult<()> { + let (session, _r_lock) = { + let (handle, _r_lock) = self.inner.wait_reconnection().await?; + (handle.session.clone(), _r_lock) + }; + session.handle_local_data_track_output(event).await; + Ok(()) + } + + pub async fn handle_remote_data_track_output( + &self, + event: dt::remote::OutputEvent, + ) -> EngineResult<()> { + let (session, _r_lock) = { + let (handle, _r_lock) = self.inner.wait_reconnection().await?; + (handle.session.clone(), _r_lock) + }; + session.handle_remote_data_track_output(event).await; + Ok(()) + } + pub async fn add_track(&self, req: proto::AddTrackRequest) -> EngineResult { let (session, _r_lock) = { let (handle, _r_lock) = self.inner.wait_reconnection().await?; @@ -611,6 +638,12 @@ impl EngineInner { SessionEvent::TrackMuted { sid, muted } => { let _ = self.engine_tx.send(EngineEvent::TrackMuted { sid, muted }); } + SessionEvent::LocalDataTrackInput(event) => { + let _ = self.engine_tx.send(EngineEvent::LocalDataTrackInput(event)); + } + SessionEvent::RemoteDataTrackInput(event) => { + let _ = self.engine_tx.send(EngineEvent::RemoteDataTrackInput(event)); + } } Ok(()) } @@ -864,3 +897,9 @@ impl EngineInner { session.wait_pc_connection().await } } + +impl From for EngineError { + fn from(err: livekit_datatrack::api::InternalError) -> Self { + Self::Internal(err.to_string().into()) + } +} diff --git a/livekit/src/rtc_engine/rtc_events.rs b/livekit/src/rtc_engine/rtc_events.rs index d5f59f794..a59aab58f 100644 --- a/livekit/src/rtc_engine/rtc_events.rs +++ b/livekit/src/rtc_engine/rtc_events.rs @@ -18,7 +18,10 @@ use tokio::sync::mpsc; use super::peer_transport::PeerTransport; use crate::{ - rtc_engine::{peer_transport::OnOfferCreated, rtc_session::RELIABLE_DC_LABEL}, + rtc_engine::{ + peer_transport::OnOfferCreated, + rtc_session::{LOSSY_DC_LABEL, RELIABLE_DC_LABEL}, + }, DataPacketKind, }; @@ -94,13 +97,15 @@ fn on_data_channel( emitter: RtcEmitter, ) -> rtc::peer_connection::OnDataChannel { Box::new(move |data_channel| { - let kind = if data_channel.label() == RELIABLE_DC_LABEL { - DataPacketKind::Reliable - } else { - DataPacketKind::Lossy - }; - data_channel.on_message(Some(on_message(emitter.clone(), kind))); - + match data_channel.label().as_str() { + RELIABLE_DC_LABEL => { + data_channel.on_message(Some(on_message(emitter.clone(), DataPacketKind::Reliable))) + } + LOSSY_DC_LABEL => { + data_channel.on_message(Some(on_message(emitter.clone(), DataPacketKind::Lossy))) + } + _ => {} + } let _ = emitter.send(RtcEvent::DataChannel { data_channel, target }); }) } diff --git a/livekit/src/rtc_engine/rtc_session.rs b/livekit/src/rtc_engine/rtc_session.rs index b24696664..3b181fae5 100644 --- a/livekit/src/rtc_engine/rtc_session.rs +++ b/livekit/src/rtc_engine/rtc_session.rs @@ -23,8 +23,10 @@ use std::{ time::Duration, }; +use bytes::Bytes; use libwebrtc::{prelude::*, stats::RtcStats}; use livekit_api::signal_client::{SignalClient, SignalEvent, SignalEvents}; +use livekit_datatrack::backend as dt; use livekit_protocol::{self as proto}; use livekit_runtime::{sleep, JoinHandle}; use parking_lot::Mutex; @@ -34,11 +36,15 @@ use proto::{ SignalTarget, }; use serde::{Deserialize, Serialize}; -use tokio::sync::{mpsc, oneshot, watch, Notify}; +use tokio::sync::{ + mpsc::{self, WeakUnboundedSender}, + oneshot, watch, Notify, +}; use super::{rtc_events, EngineError, EngineOptions, EngineResult, SimulateScenario}; use crate::{ id::ParticipantIdentity, + rtc_engine::dc_sender::{DataChannelSender, DataChannelSenderOptions}, utils::{ ttl_map::TtlMap, tx_queue::{TxQueue, TxQueueItem}, @@ -63,6 +69,7 @@ pub const ICE_CONNECT_TIMEOUT: Duration = Duration::from_secs(15); pub const TRACK_PUBLISH_TIMEOUT: Duration = Duration::from_secs(10); pub const LOSSY_DC_LABEL: &str = "_lossy"; pub const RELIABLE_DC_LABEL: &str = "_reliable"; +pub const DATA_TRACK_DC_LABEL: &str = "_data_track"; pub const RELIABLE_RECEIVED_STATE_TTL: Duration = Duration::from_secs(30); pub const PUBLISHER_NEGOTIATION_FREQUENCY: Duration = Duration::from_millis(150); pub const INITIAL_BUFFERED_AMOUNT_LOW_THRESHOLD: u64 = 2 * 1024 * 1024; @@ -198,6 +205,8 @@ pub enum SessionEvent { sid: String, muted: bool, }, + LocalDataTrackInput(dt::local::InputEvent), + RemoteDataTrackInput(dt::remote::InputEvent), } #[derive(Debug)] @@ -350,6 +359,7 @@ struct SessionInner { lossy_dc_buffered_amount_low_threshold: AtomicU64, reliable_dc: DataChannel, reliable_dc_buffered_amount_low_threshold: AtomicU64, + data_track_dc: DataChannel, /// Next sequence number for reliable packets. next_packet_sequence: AtomicU32, @@ -365,6 +375,10 @@ struct SessionInner { // so we can receive data from other participants sub_lossy_dc: Mutex>, sub_reliable_dc: Mutex>, + sub_data_track_dc: Mutex>, + + /// Channel for sending data track packets. + dt_packet_tx: mpsc::Sender, closed: AtomicBool, emitter: SessionEmitter, @@ -417,6 +431,7 @@ struct SessionHandle { signal_task: JoinHandle<()>, rtc_task: JoinHandle<()>, dc_task: JoinHandle<()>, + dt_sender_task: JoinHandle<()>, } impl RtcSession { @@ -428,7 +443,7 @@ impl RtcSession { ) -> EngineResult<(Self, proto::JoinResponse, SessionEvents)> { let (emitter, session_events) = mpsc::unbounded_channel(); - let (signal_client, join_response, signal_events) = + let (signal_client, mut join_response, signal_events) = SignalClient::connect(url, token, options.signal_options.clone()).await?; let signal_client = Arc::new(signal_client); log::debug!("received JoinResponse: {:?}", join_response); @@ -440,6 +455,9 @@ impl RtcSession { let Some(participant_info) = SessionParticipantInfo::from_join(&join_response) else { Err(EngineError::Internal("Join response missing participant info".into()))? }; + if let Ok(initial_publications) = dt::remote::event_from_join(&mut join_response) { + _ = emitter.send(SessionEvent::RemoteDataTrackInput(initial_publications.into())); + } let (rtc_emitter, rtc_events) = mpsc::unbounded_channel(); let rtc_config = make_rtc_config_join(join_response.clone(), options.rtc_config.clone()); @@ -464,21 +482,24 @@ impl RtcSession { )) }; - let mut lossy_dc = publisher_pc.peer_connection().create_data_channel( - LOSSY_DC_LABEL, - DataChannelInit { - ordered: false, - max_retransmits: Some(0), - ..DataChannelInit::default() - }, - )?; - let mut reliable_dc = publisher_pc.peer_connection().create_data_channel( RELIABLE_DC_LABEL, // Use ordered: true for reliable delivery with ordering guarantees. - DataChannelInit { ordered: true, ..DataChannelInit::default() }, + DataChannelInit { ordered: true, ..Default::default() }, )?; + let lossy_options = + DataChannelInit { ordered: false, max_retransmits: Some(0), ..Default::default() }; + + let mut lossy_dc = publisher_pc + .peer_connection() + .create_data_channel(LOSSY_DC_LABEL, lossy_options.clone())?; + + let data_track_dc = publisher_pc + .peer_connection() + .create_data_channel(DATA_TRACK_DC_LABEL, lossy_options)?; + handle_remote_dt_packets(&data_track_dc, emitter.downgrade()); + // Forward events received inside the signaling thread to our rtc channel rtc_events::forward_pc_events(&mut publisher_pc, rtc_emitter.clone()); if let Some(ref mut sub_pc) = subscriber_pc { @@ -489,6 +510,13 @@ impl RtcSession { let (close_tx, close_rx) = watch::channel(false); + let dt_sender_options = DataChannelSenderOptions { + low_buffer_threshold: INITIAL_BUFFERED_AMOUNT_LOW_THRESHOLD, + dc: data_track_dc.clone(), + close_rx: close_rx.clone(), + }; + let (dt_sender, dt_packet_tx) = DataChannelSender::new(dt_sender_options); + let inner = Arc::new(SessionInner { has_published: Default::default(), fast_publish: AtomicBool::new(join_response.fast_publish), @@ -506,12 +534,15 @@ impl RtcSession { reliable_dc_buffered_amount_low_threshold: AtomicU64::new( INITIAL_BUFFERED_AMOUNT_LOW_THRESHOLD, ), + data_track_dc, next_packet_sequence: 1.into(), packet_rx_state: Mutex::new(TtlMap::new(RELIABLE_RECEIVED_STATE_TTL)), participant_info, dc_emitter, sub_lossy_dc: Mutex::new(None), sub_reliable_dc: Mutex::new(None), + sub_data_track_dc: Mutex::new(None), + dt_packet_tx, closed: Default::default(), emitter, options, @@ -527,9 +558,17 @@ impl RtcSession { livekit_runtime::spawn(inner.clone().signal_task(signal_events, close_rx.clone())); let rtc_task = livekit_runtime::spawn(inner.clone().rtc_session_task(rtc_events, close_rx.clone())); - let dc_task = livekit_runtime::spawn(inner.clone().data_channel_task(dc_events, close_rx)); - - let handle = Mutex::new(Some(SessionHandle { close_tx, signal_task, rtc_task, dc_task })); + let dc_task = + livekit_runtime::spawn(inner.clone().data_channel_task(dc_events, close_rx.clone())); + let dt_sender_task = livekit_runtime::spawn(dt_sender.run()); + + let handle = Mutex::new(Some(SessionHandle { + close_tx, + signal_task, + rtc_task, + dc_task, + dt_sender_task, + })); // In single PC mode (or with fast_publish), trigger initial negotiation // This matches JS SDK behavior: if (!this.subscriberPrimary || joinResponse.fastPublish) { this.negotiate(); } @@ -588,6 +627,7 @@ impl RtcSession { let _ = handle.rtc_task.await; let _ = handle.signal_task.await; let _ = handle.dc_task.await; + let _ = handle.dt_sender_task.await; } // Close the PeerConnections after the task @@ -655,6 +695,54 @@ impl RtcSession { self.inner.data_channel(target, kind) } + /// Handles an event from the local data track manager. + pub(super) async fn handle_local_data_track_output(&self, event: dt::local::OutputEvent) { + use dt::local::OutputEvent; + match event { + OutputEvent::SfuPublishRequest(event) => { + if let Err(err) = self.inner.ensure_data_track_publisher_connected().await { + log::error!("Failed to open data track publish transport: {}", err); + } + self.signal_client() + .send(proto::signal_request::Message::PublishDataTrackRequest(event.into())) + .await + } + OutputEvent::SfuUnpublishRequest(event) => { + self.signal_client() + .send(proto::signal_request::Message::UnpublishDataTrackRequest(event.into())) + .await + } + OutputEvent::PacketsAvailable(packets) => self.try_send_data_track_packets(packets), + } + } + + /// Handles an event from the remote data track manager. + pub(super) async fn handle_remote_data_track_output(&self, event: dt::remote::OutputEvent) { + use dt::remote::OutputEvent; + match event.into() { + OutputEvent::SfuUpdateSubscription(event) => { + self.signal_client() + .send(proto::signal_request::Message::UpdateDataSubscription(event.into())) + .await + } + _ => {} + } + } + + /// Try to send data track packets over the transport. + /// + /// Packets will be sent until the first error, at which point the rest of + /// the batch will be dropped. + /// + fn try_send_data_track_packets(&self, packets: Vec) { + for packet in packets { + if self.inner.dt_packet_tx.try_send(packet).is_err() { + log::error!("Failed to enqueue data track packet"); + break; + } + } + } + pub fn e2ee_manager(&self) -> Option { self.inner.e2ee_manager.clone() } @@ -1046,7 +1134,14 @@ impl SessionInner { true, ); } - proto::signal_response::Message::Update(update) => { + proto::signal_response::Message::Update(mut update) => { + let local_participant_identity = self.participant_info.identity.as_str().into(); + if let Ok(event) = dt::remote::event_from_participant_update( + &mut update, + local_participant_identity, + ) { + _ = self.emitter.send(SessionEvent::RemoteDataTrackInput(event.into())); + } let _ = self .emitter .send(SessionEvent::ParticipantUpdate { updates: update.participants }); @@ -1078,11 +1173,29 @@ impl SessionInner { }); } proto::signal_response::Message::RequestResponse(request_response) => { + if let Some(event) = + dt::local::publish_result_from_request_response(&request_response) + { + _ = self.emitter.send(SessionEvent::LocalDataTrackInput(event.into())); + return Ok(()); + } let mut pending_requests = self.pending_requests.lock(); if let Some(tx) = pending_requests.remove(&request_response.request_id) { let _ = tx.send(request_response); } } + proto::signal_response::Message::PublishDataTrackResponse(publish_res) => { + let event: dt::local::SfuPublishResponse = publish_res.try_into()?; + _ = self.emitter.send(SessionEvent::LocalDataTrackInput(event.into())); + } + proto::signal_response::Message::UnpublishDataTrackResponse(unpublish_res) => { + let event: dt::local::SfuUnpublishResponse = unpublish_res.try_into()?; + _ = self.emitter.send(SessionEvent::LocalDataTrackInput(event.into())); + } + proto::signal_response::Message::DataTrackSubscriberHandles(subscriber_handles) => { + let event: dt::remote::SfuSubscriberHandles = subscriber_handles.try_into()?; + _ = self.emitter.send(SessionEvent::RemoteDataTrackInput(event.into())); + } proto::signal_response::Message::RefreshToken(ref token) => { let url = self.signal_client.url(); let _ = self.emitter.send(SessionEvent::RefreshToken { url, token: token.clone() }); @@ -1177,13 +1290,19 @@ impl SessionInner { } else { target == SignalTarget::Subscriber }; - if is_subscriber_dc { - if data_channel.label() == LOSSY_DC_LABEL { - self.sub_lossy_dc.lock().replace(data_channel); - } else if data_channel.label() == RELIABLE_DC_LABEL { - self.sub_reliable_dc.lock().replace(data_channel); - } + if !is_subscriber_dc { + return Ok(()); } + let dc_ref = match data_channel.label().as_str() { + LOSSY_DC_LABEL => &self.sub_lossy_dc, + RELIABLE_DC_LABEL => &self.sub_reliable_dc, + DATA_TRACK_DC_LABEL => { + handle_remote_dt_packets(&data_channel, self.emitter.downgrade()); + &self.sub_data_track_dc + } + _ => return Ok(()), + }; + dc_ref.lock().replace(data_channel); } RtcEvent::Offer { offer, target: _ } => { // Send the publisher offer to the server @@ -1356,14 +1475,15 @@ impl SessionInner { proto::data_packet::Value::EncryptedPacket(encrypted_packet) => { // Handle encrypted data packets if let Some(e2ee_manager) = &self.e2ee_manager { + let encryption_type = encrypted_packet.encryption_type(); let participant_identity_str = participant_identity.as_ref().map(|p| p.0.as_str()).unwrap_or(""); - match e2ee_manager.handle_encrypted_data( - &encrypted_packet.encrypted_value, - &encrypted_packet.iv, - participant_identity_str, + match e2ee_manager.decrypt_data( + encrypted_packet.encrypted_value, + encrypted_packet.iv, encrypted_packet.key_index, + participant_identity_str, ) { Some(decrypted_payload) => { // Parse the decrypted payload as EncryptedPacketPayload @@ -1376,7 +1496,7 @@ impl SessionInner { participant_sid, participant_identity, convert_encrypted_to_data_packet_value(decrypted_value), - encrypted_packet.encryption_type(), + encryption_type, ); Ok(()) } else { @@ -1653,13 +1773,14 @@ impl SessionInner { // Encode the payload and encrypt it let payload_bytes = encrypted_payload.encode_to_vec(); + let key_index = e2ee_manager .key_provider() - .map(|kp| kp.get_latest_key_index() as u32) - .unwrap_or(0); + .map_or(0, |kp| kp.get_latest_key_index() as u32); + match e2ee_manager.encrypt_data( - &payload_bytes, - &self.participant_info.identity.0, + payload_bytes, + &self.participant_info.identity, key_index, ) { Ok(encrypted_data) => { @@ -1893,11 +2014,29 @@ impl SessionInner { } } - /// Ensure the Publisher PC is connected, if not, start the negotiation - /// This is required when sending data to the server + /// Ensure the publisher peer connection and data channel for the specified packet + /// type are connected. If not, start the negotiation. + /// + /// This is required when sending data to the server. + /// async fn ensure_publisher_connected( self: &Arc, kind: DataPacketKind, + ) -> EngineResult<()> { + let required_dc = self.data_channel(SignalTarget::Publisher, kind).unwrap(); + self.ensure_publisher_connected_with_dc(required_dc).await?; + Ok(()) + } + + /// Ensure the required data channel for publishing data track frames is open. + async fn ensure_data_track_publisher_connected(self: &Arc) -> EngineResult<()> { + self.ensure_publisher_connected_with_dc(self.data_track_dc.clone()).await?; + Ok(()) + } + + async fn ensure_publisher_connected_with_dc( + self: &Arc, + required_dc: DataChannel, ) -> EngineResult<()> { if !self.has_published.load(Ordering::Acquire) { // The publisher has never been connected, start the negotiation @@ -1905,14 +2044,14 @@ impl SessionInner { self.publisher_negotiation_needed(); } - let dc = self.data_channel(SignalTarget::Publisher, kind).unwrap(); - if dc.state() == DataChannelState::Open { + if required_dc.state() == DataChannelState::Open { return Ok(()); } // Wait until the PeerConnection is connected let wait_connected = async { - while !self.publisher_pc.is_connected() || dc.state() != DataChannelState::Open { + while !self.publisher_pc.is_connected() || required_dc.state() != DataChannelState::Open + { if self.closed.load(Ordering::Acquire) { return Err(EngineError::Connection("closed".into())); } @@ -1971,6 +2110,20 @@ impl SessionInner { } } +/// Emit incoming data track packets as session events. +pub fn handle_remote_dt_packets(dc: &DataChannel, emitter: WeakUnboundedSender) { + let on_message: libwebrtc::data_channel::OnMessage = Box::new(move |buffer: DataBuffer| { + if !buffer.binary { + log::error!("Received non-binary message"); + return; + } + let packet: Bytes = buffer.data.to_vec().into(); // TODO: avoid clone if possible + let Some(emitter) = emitter.upgrade() else { return }; + _ = emitter.send(SessionEvent::RemoteDataTrackInput(packet.into())); + }); + dc.on_message(on_message.into()); +} + macro_rules! make_rtc_config { ($fncname:ident, $proto:ty) => { fn $fncname(value: $proto, mut config: RtcConfiguration) -> RtcConfiguration { diff --git a/livekit/tests/common/e2e/mod.rs b/livekit/tests/common/e2e/mod.rs index 00af8a307..4f5e70ddb 100644 --- a/livekit/tests/common/e2e/mod.rs +++ b/livekit/tests/common/e2e/mod.rs @@ -45,18 +45,47 @@ impl TestEnvironment { } } +#[derive(Debug, Clone)] +pub struct TestRoomOptions { + /// Grants for the generated token. + pub grants: VideoGrants, + /// Options used for creating the [`Room`]. + pub room: RoomOptions, +} + +impl Default for TestRoomOptions { + fn default() -> Self { + Self { + grants: VideoGrants { room_join: true, ..Default::default() }, + room: Default::default(), + } + } +} + +impl From for TestRoomOptions { + fn from(room: RoomOptions) -> Self { + Self { room, ..Default::default() } + } +} + +impl From for TestRoomOptions { + fn from(grants: VideoGrants) -> Self { + Self { grants, ..Default::default() } + } +} + fn is_local_server_url(url: &str) -> bool { url.contains("localhost:7880") || url.contains("127.0.0.1:7880") } /// Creates the specified number of connections to a shared room for testing. pub async fn test_rooms(count: usize) -> Result)>> { - test_rooms_with_options((0..count).map(|_| RoomOptions::default())).await + test_rooms_with_options((0..count).map(|_| TestRoomOptions::default())).await } /// Creates multiple connections to a shared room for testing, one for each configuration. pub async fn test_rooms_with_options( - options: impl IntoIterator, + options: impl IntoIterator, ) -> Result)>> { let test_env = TestEnvironment::from_env_or_defaults(); let force_v0 = is_local_server_url(&test_env.server_url); @@ -73,18 +102,18 @@ pub async fn test_rooms_with_options( if force_v0 { // Local dev server generally does not support /rtc/v1. Force v0 in generic E2E // tests to avoid extra fallback latency/flakiness. - options.single_peer_connection = false; + options.room.single_peer_connection = false; } - let grants = - VideoGrants { room_join: true, room: room_name.clone(), ..Default::default() }; + options.grants.room = room_name.clone(); + let token = AccessToken::with_api_key(&test_env.api_key, &test_env.api_secret) .with_ttl(Duration::from_secs(30 * 60)) // 30 minutes - .with_grants(grants) + .with_grants(options.grants) .with_identity(&format!("p{}", id)) .with_name(&format!("Participant {}", id)) .to_jwt() .context("Failed to generate JWT")?; - Ok((token, options)) + Ok((token, options.room)) }) .collect::>>()?; diff --git a/livekit/tests/data_channel_encryption.rs b/livekit/tests/data_channel_encryption.rs index 2e1515423..f36207295 100644 --- a/livekit/tests/data_channel_encryption.rs +++ b/livekit/tests/data_channel_encryption.rs @@ -49,7 +49,7 @@ async fn test_data_channel_encryption() -> Result<()> { options2.encryption = Some(E2eeOptions { key_provider: key_provider2, encryption_type: EncryptionType::Gcm }); - let mut rooms = test_rooms_with_options([options1, options2]).await?; + let mut rooms = test_rooms_with_options([options1.into(), options2.into()]).await?; let (sending_room, _) = rooms.pop().unwrap(); let (receiving_room, mut receiving_event_rx) = rooms.pop().unwrap(); diff --git a/livekit/tests/data_track_test.rs b/livekit/tests/data_track_test.rs new file mode 100644 index 000000000..7d762641d --- /dev/null +++ b/livekit/tests/data_track_test.rs @@ -0,0 +1,446 @@ +// Copyright 2025 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#[cfg(feature = "__lk-e2e-test")] +use { + anyhow::{anyhow, Ok, Result}, + common::{test_rooms, test_rooms_with_options, TestRoomOptions}, + futures_util::StreamExt, + livekit::{prelude::*, SimulateScenario}, + livekit_api::access_token::VideoGrants, + std::time::{Duration, Instant}, + test_case::test_case, + tokio::{ + time::{self, timeout}, + try_join, + }, +}; + +mod common; + +#[cfg(feature = "__lk-e2e-test")] +#[test_case(120., 8_192 ; "high_fps_single_packet")] +#[test_case(10., 196_608 ; "low_fps_multi_packet")] +#[test_log::test(tokio::test)] +async fn test_data_track(publish_fps: f64, payload_len: usize) -> Result<()> { + // How long to publish frames for. + const PUBLISH_DURATION: Duration = Duration::from_secs(5); + + // Percentage of total frames that must be received on the subscriber end in + // order for the test to pass. + const MIN_PERCENTAGE: f32 = 0.95; + + let mut rooms = test_rooms(2).await?; + + let (pub_room, _) = rooms.pop().unwrap(); + let (_, mut sub_room_event_rx) = rooms.pop().unwrap(); + let pub_identity = pub_room.local_participant().identity(); + + let frame_count = (PUBLISH_DURATION.as_secs_f64() * publish_fps).round() as u64; + log::info!("Publishing {} frames", frame_count); + + let publish = async move { + let track = pub_room.local_participant().publish_data_track("my_track").await?; + log::info!("Track published"); + + assert!(track.is_published()); + assert!(!track.info().uses_e2ee()); + assert_eq!(track.info().name(), "my_track"); + + let sleep_duration = Duration::from_secs_f64(1.0 / publish_fps as f64); + for index in 0..frame_count { + track.try_push(vec![index as u8; payload_len].into())?; + time::sleep(sleep_duration).await; + } + Ok(()) + }; + + let subscribe = async move { + let track = wait_for_remote_track(&mut sub_room_event_rx).await?; + + log::info!("Got remote track: {}", track.info().sid()); + assert!(track.is_published()); + assert!(!track.info().uses_e2ee()); + assert_eq!(track.info().name(), "my_track"); + assert_eq!(track.publisher_identity(), pub_identity.as_str()); + + let mut subscription = track.subscribe().await?; + + let mut recv_count = 0; + while let Some(frame) = subscription.next().await { + let payload = frame.payload(); + if let Some(first_byte) = payload.first() { + assert!(payload.iter().all(|byte| byte == first_byte)); + } + assert_eq!(frame.user_timestamp(), None); + recv_count += 1; + } + + let recv_percent = recv_count as f32 / frame_count as f32; + log::info!("Received {}/{} frames ({:.2}%)", recv_count, frame_count, recv_percent * 100.); + + if recv_percent < MIN_PERCENTAGE { + Err(anyhow!("Not enough frames received"))?; + } + Ok(()) + }; + timeout(PUBLISH_DURATION + Duration::from_secs(25), async { try_join!(publish, subscribe) }) + .await??; + Ok(()) +} + +#[cfg(feature = "__lk-e2e-test")] +#[test_log::test(tokio::test)] +async fn test_publish_many_tracks() -> Result<()> { + const TRACK_COUNT: usize = 256; + + let (room, _) = test_rooms(1).await?.pop().unwrap(); + + let publish_tracks = async { + let mut tracks = Vec::with_capacity(TRACK_COUNT); + let start = Instant::now(); + + for idx in 0..TRACK_COUNT { + let name = format!("track_{}", idx); + let track = room.local_participant().publish_data_track(name.clone()).await?; + + assert!(track.is_published()); + assert_eq!(track.info().name(), name); + + tracks.push(track); + } + + let elapsed = start.elapsed(); + log::info!( + "Publishing {} tracks took {:.2?} (average {:.2?} per track)", + TRACK_COUNT, + elapsed, + elapsed / TRACK_COUNT as u32 + ); + Ok(tracks) + }; + + let tracks = timeout(Duration::from_secs(5), publish_tracks).await??; + for track in &tracks { + // Publish a single large frame per track. + track.try_push(vec![0xFA; 196_608].into())?; + } + Ok(()) +} + +#[cfg(feature = "__lk-e2e-test")] +#[test_log::test(tokio::test)] +async fn test_publish_unauthorized() -> Result<()> { + let (room, _) = test_rooms_with_options([TestRoomOptions { + grants: VideoGrants { room_join: true, can_publish_data: false, ..Default::default() }, + ..Default::default() + }]) + .await? + .pop() + .unwrap(); + + let result = room.local_participant().publish_data_track("my_track").await; + assert!(matches!(result, Err(PublishError::NotAllowed))); + + Ok(()) +} + +#[cfg(feature = "__lk-e2e-test")] +#[test_log::test(tokio::test)] +async fn test_publish_duplicate_name() -> Result<()> { + let (room, _) = test_rooms(1).await?.pop().unwrap(); + + #[allow(unused)] + let first = room.local_participant().publish_data_track("first").await?; + + let second_result = room.local_participant().publish_data_track("first").await; + assert!(matches!(second_result, Err(PublishError::DuplicateName))); + + Ok(()) +} + +#[cfg(feature = "__lk-e2e-test")] +#[test_log::test(tokio::test)] +async fn test_e2ee() -> Result<()> { + use livekit::e2ee::{ + key_provider::{KeyProvider, KeyProviderOptions}, + EncryptionType, + }; + use livekit::E2eeOptions; + + const SHARED_SECRET: &[u8] = b"password"; + + let key_provider1 = + KeyProvider::with_shared_key(KeyProviderOptions::default(), SHARED_SECRET.to_vec()); + + let mut options1 = RoomOptions::default(); + options1.encryption = + Some(E2eeOptions { key_provider: key_provider1, encryption_type: EncryptionType::Gcm }); + + let key_provider2 = + KeyProvider::with_shared_key(KeyProviderOptions::default(), SHARED_SECRET.to_vec()); + + let mut options2 = RoomOptions::default(); + options2.encryption = + Some(E2eeOptions { key_provider: key_provider2, encryption_type: EncryptionType::Gcm }); + + let mut rooms = test_rooms_with_options([options1.into(), options2.into()]).await?; + + let (pub_room, _) = rooms.pop().unwrap(); + let (sub_room, mut sub_room_event_rx) = rooms.pop().unwrap(); + + pub_room.e2ee_manager().set_enabled(true); + sub_room.e2ee_manager().set_enabled(true); + + let publish = async move { + let track = pub_room.local_participant().publish_data_track("my_track").await?; + assert!(track.info().uses_e2ee()); + + for index in 0..5 { + track.try_push(vec![index as u8; 196_608].into())?; + time::sleep(Duration::from_millis(25)).await; + } + Ok(()) + }; + + let subscribe = async move { + let track = wait_for_remote_track(&mut sub_room_event_rx).await?; + + assert!(track.info().uses_e2ee()); + let mut subscription = track.subscribe().await?; + + while let Some(frame) = subscription.next().await { + let payload = frame.payload(); + if let Some(first_byte) = payload.first() { + assert!(payload.iter().all(|byte| byte == first_byte)); + } + } + Ok(()) + }; + timeout(Duration::from_secs(5), async { try_join!(publish, subscribe) }).await??; + Ok(()) +} + +#[cfg(feature = "__lk-e2e-test")] +#[test_log::test(tokio::test)] +async fn test_published_state() -> Result<()> { + // How long to leave the track published. + const PUBLISH_DURATION: Duration = Duration::from_millis(500); + + let mut rooms = test_rooms(2).await?; + + let (pub_room, _) = rooms.pop().unwrap(); + let (_, mut sub_room_event_rx) = rooms.pop().unwrap(); + + let publish = async move { + let track = pub_room.local_participant().publish_data_track("my_track").await?; + + assert!(track.is_published()); + time::sleep(PUBLISH_DURATION).await; + track.unpublish(); + + Ok(()) + }; + + let subscribe = async move { + let track = wait_for_remote_track(&mut sub_room_event_rx).await?; + assert!(track.is_published()); + + let elapsed = { + let start = Instant::now(); + track.wait_for_unpublish().await; + start.elapsed() + }; + assert!(elapsed.abs_diff(PUBLISH_DURATION) <= Duration::from_millis(20)); + assert!(!track.is_published()); + + Ok(()) + }; + + timeout(Duration::from_secs(5), async { try_join!(publish, subscribe) }).await??; + Ok(()) +} + +#[cfg(feature = "__lk-e2e-test")] +#[test_log::test(tokio::test)] +async fn test_resubscribe() -> Result<()> { + const ITERATIONS: usize = 10; + + let mut rooms = test_rooms(2).await?; + + let (pub_room, _) = rooms.pop().unwrap(); + let (_, mut sub_room_event_rx) = rooms.pop().unwrap(); + + let publish = async move { + let track = pub_room.local_participant().publish_data_track("my_track").await.unwrap(); + loop { + _ = track.try_push(vec![0xFA; 64].into()); + time::sleep(Duration::from_millis(50)).await; + } + }; + + let subscribe = async move { + let track = wait_for_remote_track(&mut sub_room_event_rx).await.unwrap(); + + let mut successful_subscriptions = 0; + for _ in 0..ITERATIONS { + let mut stream = track.subscribe().await.unwrap(); + while let Some(frame) = stream.next().await { + // Ensure we can at least get one frame. + assert!(!frame.payload().is_empty()); + successful_subscriptions += 1; + break; + } + std::mem::drop(stream); + time::sleep(Duration::from_millis(50)).await; + } + assert_eq!(successful_subscriptions, ITERATIONS); + }; + + let _ = timeout(Duration::from_secs(5), async { + tokio::select! { _ = publish => (), _ = subscribe => () }; + }) + .await?; + Ok(()) +} + +#[cfg(feature = "__lk-e2e-test")] +#[test_log::test(tokio::test)] +async fn test_frame_with_user_timestamp() -> Result<()> { + let mut rooms = test_rooms(2).await?; + + let (pub_room, _) = rooms.pop().unwrap(); + let (_, mut sub_room_event_rx) = rooms.pop().unwrap(); + + let publish = async move { + let track = pub_room.local_participant().publish_data_track("my_track").await.unwrap(); + loop { + let frame = DataTrackFrame::new(vec![0xFA; 64]).with_user_timestamp_now(); + _ = track.try_push(frame); + time::sleep(Duration::from_millis(50)).await; + } + }; + + let subscribe = async move { + let track = wait_for_remote_track(&mut sub_room_event_rx).await.unwrap(); + + let mut stream = track.subscribe().await.unwrap(); + let mut got_frame = false; + while let Some(frame) = stream.next().await { + // Ensure we can at least get one frame. + assert!(!frame.payload().is_empty()); + let duration = frame.duration_since_timestamp().expect("Missing timestamp"); + assert!(duration.as_millis() < 1000); + got_frame = true; + break; + } + if !got_frame { + panic!("No frame received"); + } + }; + + let _ = timeout(Duration::from_secs(5), async { + tokio::select! { _ = publish => (), _ = subscribe => () }; + }) + .await?; + Ok(()) +} + +#[cfg(feature = "__lk-e2e-test")] +#[test_case(SimulateScenario::SignalReconnect; "signal_reconnect")] +#[test_case(SimulateScenario::ForceTcp; "full_reconnect")] +#[test_log::test(tokio::test)] +async fn test_subscriber_side_fault(scenario: SimulateScenario) -> Result<()> { + let mut rooms = test_rooms(2).await?; + + let (pub_room, _) = rooms.pop().unwrap(); + let (sub_room, mut sub_room_event_rx) = rooms.pop().unwrap(); + + let publish = async move { + let track = pub_room.local_participant().publish_data_track("my_track").await.unwrap(); + loop { + _ = track.try_push(vec![0xFA; 64].into()); + time::sleep(Duration::from_millis(50)).await; + } + }; + + let subscribe = async move { + let track = wait_for_remote_track(&mut sub_room_event_rx).await.unwrap(); + let mut stream = track.subscribe().await.unwrap(); + + // TODO: this should also evaluate what happens if a track subscription is removed + // during a full reconnect event. + sub_room.simulate_scenario(scenario).await.unwrap(); + assert!(track.is_published()); + + let mut got_frame = false; + while let Some(frame) = stream.next().await { + // Ensure we can at least get one frame. + assert!(!frame.payload().is_empty()); + got_frame = true; + break; + } + if !got_frame { + panic!("No frame received"); + } + }; + + let _ = timeout(Duration::from_secs(15), async { + tokio::select! { _ = publish => (), _ = subscribe => () }; + }) + .await?; + Ok(()) +} + +#[cfg(feature = "__lk-e2e-test")] +#[test_case(SimulateScenario::SignalReconnect; "signal_reconnect")] +#[test_case(SimulateScenario::ForceTcp; "full_reconnect")] +#[test_log::test(tokio::test)] +async fn test_publisher_side_fault(scenario: SimulateScenario) -> Result<()> { + let mut rooms = test_rooms(1).await?; + let (pub_room, _) = rooms.pop().unwrap(); + + let publish = async move { + let track = pub_room.local_participant().publish_data_track("my_track").await.unwrap(); + let initial_sid = track.info().sid().clone(); + + pub_room.simulate_scenario(scenario).await.unwrap(); + assert!(track.is_published(), "Should still be reported as published"); + + if scenario == SimulateScenario::ForceTcp { + // Give some time for the track to be republished. Frames will be dropped until then. + time::sleep(Duration::from_millis(2000)).await; + assert_ne!(initial_sid, track.info().sid(), "Should have new SID"); + } + + assert!(track.is_published(), "Should still be reported as published"); + track.try_push(vec![0xFA; 64].into()).expect("Should be able to push frame"); + }; + + let _ = timeout(Duration::from_secs(10), publish).await?; + Ok(()) +} + +/// Waits for the first remote data track to be published. +#[cfg(feature = "__lk-e2e-test")] +async fn wait_for_remote_track( + rx: &mut tokio::sync::mpsc::UnboundedReceiver, +) -> Result { + while let Some(event) = rx.recv().await { + if let RoomEvent::DataTrackPublished(track) = event { + return Ok(track); + } + } + Err(anyhow!("No track published")) +} diff --git a/livekit/tests/peer_connection_signaling_test.rs b/livekit/tests/peer_connection_signaling_test.rs index 852715f23..8724f33c7 100644 --- a/livekit/tests/peer_connection_signaling_test.rs +++ b/livekit/tests/peer_connection_signaling_test.rs @@ -61,6 +61,7 @@ use std::time::{Duration, Instant}; use tokio::sync::mpsc::UnboundedReceiver; use tokio::time::timeout; +use crate::common::TestRoomOptions; use crate::common::{ audio::{SineParameters, SineTrack}, test_rooms_with_options, @@ -154,12 +155,12 @@ fn create_token( } /// Create room options for the specified signaling mode -fn room_options(mode: SignalingMode) -> RoomOptions { - let mut options = RoomOptions::default(); - options.auto_subscribe = true; - options.dynacast = false; - options.single_peer_connection = mode.is_single_pc(); - options +fn room_options(mode: SignalingMode) -> TestRoomOptions { + let mut room = RoomOptions::default(); + room.auto_subscribe = true; + room.dynacast = false; + room.single_peer_connection = mode.is_single_pc(); + TestRoomOptions { room, ..Default::default() } } /// Connect to a room with specified signaling mode @@ -170,7 +171,8 @@ async fn connect_room( ) -> Result<(Room, UnboundedReceiver)> { let options = room_options(mode); let started_at = Instant::now(); - let result = Room::connect(url, token, options).await; + // TODO: this should be using the common testing utilities. + let result = Room::connect(url, token, options.room).await; let elapsed = started_at.elapsed(); match &result {