diff --git a/src/main.rs b/src/main.rs index d662850..2c394f0 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,7 +1,9 @@ use std::net::IpAddr; use std::path::PathBuf; +use std::str::FromStr; use clap::{Parser, Subcommand}; +use reqwest::header::{HeaderMap, HeaderName, HeaderValue}; use tokio::spawn; use tokio_util::sync::CancellationToken; use url::Url; @@ -71,6 +73,11 @@ enum CliCommand { /// - If unset, the recording stops at the last segment of the last media playlist. #[arg(long, allow_hyphen_values = true, verbatim_doc_comment)] end: Option, + /// Custom HTTP header to send with all requests. + /// + /// Can be specified multiple times. Format: "Name: Value" + #[arg(short = 'H', long = "header", value_name = "Name: Value", value_parser = parse_header)] + headers: Vec<(HeaderName, HeaderValue)>, }, /// Replay a HLS VOD or live stream. Replay { @@ -113,6 +120,7 @@ async fn main() { bandwidth, start, end, + headers, } => { let variant_select = if let Some(bandwidth) = bandwidth { VariantSelectOptions::Bandwidth(bandwidth) @@ -126,6 +134,7 @@ async fn main() { audio, video, subtitle, + headers: headers.into_iter().collect::(), }; let token = CancellationToken::new(); let record_task = { @@ -166,3 +175,14 @@ async fn main() { } } } + +fn parse_header(s: &str) -> Result<(HeaderName, HeaderValue), String> { + let (name, value) = s + .split_once(':') + .ok_or_else(|| "invalid header (expected \"Name: Value\")".to_string())?; + + Ok(( + HeaderName::from_str(name.trim()).map_err(|e| e.to_string())?, + HeaderValue::from_str(value.trim()).map_err(|e| e.to_string())?, + )) +} diff --git a/src/record/mod.rs b/src/record/mod.rs index 67086b8..5f76aba 100644 --- a/src/record/mod.rs +++ b/src/record/mod.rs @@ -4,6 +4,7 @@ use futures::future::{BoxFuture, FutureExt}; use futures::stream::{StreamExt, TryStreamExt, iter}; use m3u8_rs::*; use reqwest::Client; +use reqwest::header::HeaderMap; use std::io; use std::path::{Path, PathBuf}; use std::sync::Arc; @@ -24,7 +25,7 @@ mod rewrite; const MAX_CONCURRENT_DOWNLOADS: usize = 4; -#[derive(Debug, Copy, Clone)] +#[derive(Debug, Clone)] pub struct RecordOptions { pub variant_select: VariantSelectOptions, pub audio: MediaSelect, @@ -32,6 +33,7 @@ pub struct RecordOptions { pub subtitle: MediaSelect, pub start: Option, pub end: Option, + pub headers: HeaderMap, } #[derive(thiserror::Error, Debug)] @@ -58,6 +60,7 @@ pub async fn record( let recording = Arc::new(Mutex::new(recording)); let client = Client::builder() .cookie_store(true) + .default_headers(options.headers.clone()) .build() .map_err(|_| RecordError::Config("Error while building HTTP client"))?; // Download initial playlist @@ -177,6 +180,7 @@ async fn record_master_playlist( QuotedOrUnquoted::Quoted(variant_url.as_str().to_string()), ); variant.uri = format!("{variant_dir}index.m3u8"); + let options = options.clone(); let token = token.clone(); join_set.spawn(async move { record_media_playlist( @@ -208,6 +212,7 @@ async fn record_master_playlist( QuotedOrUnquoted::Quoted(media_url.as_str().to_string()), ); media.uri = Some(format!("{media_dir}index.m3u8")); + let options = options.clone(); let token = token.clone(); join_set.spawn(async move { record_media_playlist(