Skip to content

Commit 82ced00

Browse files
committed
feat(lua): add response header manipulation support for Lua scripts
- Introduce CandyHeaders struct wrapping HeaderMap with Arc<Mutex> - Implement __index/__newindex metamethods for header access in Lua - Support both bracket and dot notation (cd.header["Content-Type"], cd.header.content_type) - Handle multi-value headers via Lua tables (e.g., Set-Cookie) - Auto-convert underscores to hyphens in header names for Lua-friendly syntax
1 parent 913e921 commit 82ced00

1 file changed

Lines changed: 174 additions & 55 deletions

File tree

src/http/lua.rs

Lines changed: 174 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
use std::str::FromStr;
1+
use std::sync::{Arc, Mutex};
22

33
use anyhow::{Context, anyhow};
44
use axum::{
@@ -27,11 +27,116 @@ struct CandyRequest {
2727
/// Uri 在路由中被添加到上下文中
2828
uri: Uri,
2929
}
30+
31+
/// HTTP 响应头包装器,支持 Lua 访问
32+
#[derive(Clone, Debug)]
33+
struct CandyHeaders {
34+
headers: Arc<Mutex<HeaderMap>>,
35+
}
36+
37+
impl CandyHeaders {
38+
fn new(headers: HeaderMap) -> Self {
39+
Self {
40+
headers: Arc::new(Mutex::new(headers)),
41+
}
42+
}
43+
44+
/// 将 Lua 风格的 header 名转换为 HTTP header 名
45+
/// 下划线转换为连字符,如 content_type -> Content-Type
46+
fn normalize_header_name(key: &str) -> String {
47+
key.replace('_', "-")
48+
}
49+
}
50+
51+
impl UserData for CandyHeaders {
52+
fn add_methods<M: UserDataMethods<Self>>(methods: &mut M) {
53+
// __index: 读取 header
54+
// 支持 cd.header["Content-Type"] 和 cd.header.content_type
55+
methods.add_meta_method("__index", |lua, this, key: String| {
56+
let normalized = Self::normalize_header_name(&key);
57+
let headers = this.headers.lock().map_err(|e| {
58+
mlua::Error::external(anyhow!("Failed to lock headers: {}", e))
59+
})?;
60+
61+
// 查找 header (大小写不敏感)
62+
let header_name = HeaderName::try_from(normalized.as_str())
63+
.map_err(|e| mlua::Error::external(anyhow!("Invalid header name: {}", e)))?;
64+
65+
let values: Vec<String> = headers
66+
.get_all(&header_name)
67+
.iter()
68+
.filter_map(|v| v.to_str().ok().map(|s| s.to_string()))
69+
.collect();
70+
71+
if values.is_empty() {
72+
Ok(mlua::Value::Nil)
73+
} else if values.len() == 1 {
74+
Ok(mlua::Value::String(lua.create_string(&values[0])?))
75+
} else {
76+
// 多值 header 返回 table
77+
let table = lua.create_table()?;
78+
for (i, v) in values.iter().enumerate() {
79+
table.set(i + 1, v.clone())?;
80+
}
81+
Ok(mlua::Value::Table(table))
82+
}
83+
});
84+
85+
// __newindex: 设置/删除 header
86+
// cd.header["Content-Type"] = "text/plain"
87+
// cd.header["Set-Cookie"] = {"a=1", "b=2"}
88+
// cd.header["X-My-Header"] = nil -- 删除
89+
methods.add_meta_method_mut("__newindex", |_lua, this, (key, value): (String, mlua::Value)| {
90+
let normalized = Self::normalize_header_name(&key);
91+
let header_name = HeaderName::try_from(normalized.as_str())
92+
.map_err(|e| mlua::Error::external(anyhow!("Invalid header name: {}", e)))?;
93+
94+
let mut headers = this.headers.lock().map_err(|e| {
95+
mlua::Error::external(anyhow!("Failed to lock headers: {}", e))
96+
})?;
97+
98+
// 先移除已有的值
99+
headers.remove(&header_name);
100+
101+
match value {
102+
mlua::Value::Nil => {
103+
// 删除 header,已经 remove 了,不需要额外操作
104+
}
105+
mlua::Value::String(s) => {
106+
let val = s.to_str()?;
107+
let header_value = HeaderValue::from_str(&val)
108+
.map_err(|e| mlua::Error::external(anyhow!("Invalid header value: {}", e)))?;
109+
headers.append(header_name.clone(), header_value);
110+
}
111+
mlua::Value::Table(t) => {
112+
// 多值 header
113+
for pair in t.pairs::<i32, mlua::String>() {
114+
let (_, v) = pair.map_err(|e| {
115+
mlua::Error::external(anyhow!("Invalid header value in table: {}", e))
116+
})?;
117+
let val = v.to_str()?;
118+
let header_value = HeaderValue::from_str(&val)
119+
.map_err(|e| mlua::Error::external(anyhow!("Invalid header value: {}", e)))?;
120+
headers.append(header_name.clone(), header_value);
121+
}
122+
}
123+
_ => {
124+
return Err(mlua::Error::external(anyhow!(
125+
"Header value must be string, table, or nil"
126+
)));
127+
}
128+
}
129+
130+
Ok(())
131+
});
132+
}
133+
}
134+
30135
/// 为 Lua 脚本提供 HTTP 响应上下文
31136
#[derive(Clone, Debug)]
32137
struct CandyResponse {
33138
status: u16,
34-
headers: HeaderMap,
139+
headers: CandyHeaders,
35140
body: String,
36141
}
37142
// HTTP 请求上下文,可在 Lua 中使用
@@ -43,67 +148,72 @@ struct RequestContext {
43148

44149
impl UserData for RequestContext {
45150
fn add_methods<M: UserDataMethods<Self>>(methods: &mut M) {
46-
// 元方法:实现属性访问 (cd.status)
151+
// 元方法:实现属性访问 (cd.status, cd.header)
47152
// 注意:需要同时处理常量字段和动态属性
48-
methods.add_meta_method("__index", |_, this, key: String| {
153+
methods.add_meta_method("__index", |lua, this, key: String| {
49154
match key.as_str() {
50155
// 动态属性
51-
"status" => Ok(this.res.status),
156+
"status" => lua.pack(this.res.status),
157+
"header" => {
158+
// 返回 headers 对象
159+
lua.create_userdata(this.res.headers.clone())
160+
.map(mlua::Value::UserData)
161+
}
52162
// HTTP 方法常量
53-
"HTTP_GET" => Ok(0u16),
54-
"HTTP_HEAD" => Ok(1u16),
55-
"HTTP_PUT" => Ok(2u16),
56-
"HTTP_POST" => Ok(3u16),
57-
"HTTP_DELETE" => Ok(4u16),
58-
"HTTP_OPTIONS" => Ok(5u16),
59-
"HTTP_MKCOL" => Ok(6u16),
60-
"HTTP_COPY" => Ok(7u16),
61-
"HTTP_MOVE" => Ok(8u16),
62-
"HTTP_PROPFIND" => Ok(9u16),
63-
"HTTP_PROPPATCH" => Ok(10u16),
64-
"HTTP_LOCK" => Ok(11u16),
65-
"HTTP_UNLOCK" => Ok(12u16),
66-
"HTTP_PATCH" => Ok(13u16),
67-
"HTTP_TRACE" => Ok(14u16),
163+
"HTTP_GET" => lua.pack(0u16),
164+
"HTTP_HEAD" => lua.pack(1u16),
165+
"HTTP_PUT" => lua.pack(2u16),
166+
"HTTP_POST" => lua.pack(3u16),
167+
"HTTP_DELETE" => lua.pack(4u16),
168+
"HTTP_OPTIONS" => lua.pack(5u16),
169+
"HTTP_MKCOL" => lua.pack(6u16),
170+
"HTTP_COPY" => lua.pack(7u16),
171+
"HTTP_MOVE" => lua.pack(8u16),
172+
"HTTP_PROPFIND" => lua.pack(9u16),
173+
"HTTP_PROPPATCH" => lua.pack(10u16),
174+
"HTTP_LOCK" => lua.pack(11u16),
175+
"HTTP_UNLOCK" => lua.pack(12u16),
176+
"HTTP_PATCH" => lua.pack(13u16),
177+
"HTTP_TRACE" => lua.pack(14u16),
68178
// HTTP 状态码常量 - 1xx
69-
"HTTP_CONTINUE" => Ok(100u16),
70-
"HTTP_SWITCHING_PROTOCOLS" => Ok(101u16),
179+
"HTTP_CONTINUE" => lua.pack(100u16),
180+
"HTTP_SWITCHING_PROTOCOLS" => lua.pack(101u16),
71181
// HTTP 状态码常量 - 2xx
72-
"HTTP_OK" => Ok(200u16),
73-
"HTTP_CREATED" => Ok(201u16),
74-
"HTTP_ACCEPTED" => Ok(202u16),
75-
"HTTP_NO_CONTENT" => Ok(204u16),
76-
"HTTP_PARTIAL_CONTENT" => Ok(206u16),
182+
"HTTP_OK" => lua.pack(200u16),
183+
"HTTP_CREATED" => lua.pack(201u16),
184+
"HTTP_ACCEPTED" => lua.pack(202u16),
185+
"HTTP_NO_CONTENT" => lua.pack(204u16),
186+
"HTTP_PARTIAL_CONTENT" => lua.pack(206u16),
77187
// HTTP 状态码常量 - 3xx
78-
"HTTP_SPECIAL_RESPONSE" => Ok(300u16),
79-
"HTTP_MOVED_PERMANENTLY" => Ok(301u16),
80-
"HTTP_MOVED_TEMPORARILY" => Ok(302u16),
81-
"HTTP_SEE_OTHER" => Ok(303u16),
82-
"HTTP_NOT_MODIFIED" => Ok(304u16),
83-
"HTTP_TEMPORARY_REDIRECT" => Ok(307u16),
188+
"HTTP_SPECIAL_RESPONSE" => lua.pack(300u16),
189+
"HTTP_MOVED_PERMANENTLY" => lua.pack(301u16),
190+
"HTTP_MOVED_TEMPORARILY" => lua.pack(302u16),
191+
"HTTP_SEE_OTHER" => lua.pack(303u16),
192+
"HTTP_NOT_MODIFIED" => lua.pack(304u16),
193+
"HTTP_TEMPORARY_REDIRECT" => lua.pack(307u16),
84194
// HTTP 状态码常量 - 4xx
85-
"HTTP_BAD_REQUEST" => Ok(400u16),
86-
"HTTP_UNAUTHORIZED" => Ok(401u16),
87-
"HTTP_PAYMENT_REQUIRED" => Ok(402u16),
88-
"HTTP_FORBIDDEN" => Ok(403u16),
89-
"HTTP_NOT_FOUND" => Ok(404u16),
90-
"HTTP_NOT_ALLOWED" => Ok(405u16),
91-
"HTTP_NOT_ACCEPTABLE" => Ok(406u16),
92-
"HTTP_REQUEST_TIMEOUT" => Ok(408u16),
93-
"HTTP_CONFLICT" => Ok(409u16),
94-
"HTTP_GONE" => Ok(410u16),
95-
"HTTP_UPGRADE_REQUIRED" => Ok(426u16),
96-
"HTTP_TOO_MANY_REQUESTS" => Ok(429u16),
97-
"HTTP_CLOSE" => Ok(444u16),
98-
"HTTP_ILLEGAL" => Ok(451u16),
195+
"HTTP_BAD_REQUEST" => lua.pack(400u16),
196+
"HTTP_UNAUTHORIZED" => lua.pack(401u16),
197+
"HTTP_PAYMENT_REQUIRED" => lua.pack(402u16),
198+
"HTTP_FORBIDDEN" => lua.pack(403u16),
199+
"HTTP_NOT_FOUND" => lua.pack(404u16),
200+
"HTTP_NOT_ALLOWED" => lua.pack(405u16),
201+
"HTTP_NOT_ACCEPTABLE" => lua.pack(406u16),
202+
"HTTP_REQUEST_TIMEOUT" => lua.pack(408u16),
203+
"HTTP_CONFLICT" => lua.pack(409u16),
204+
"HTTP_GONE" => lua.pack(410u16),
205+
"HTTP_UPGRADE_REQUIRED" => lua.pack(426u16),
206+
"HTTP_TOO_MANY_REQUESTS" => lua.pack(429u16),
207+
"HTTP_CLOSE" => lua.pack(444u16),
208+
"HTTP_ILLEGAL" => lua.pack(451u16),
99209
// HTTP 状态码常量 - 5xx
100-
"HTTP_INTERNAL_SERVER_ERROR" => Ok(500u16),
101-
"HTTP_METHOD_NOT_IMPLEMENTED" => Ok(501u16),
102-
"HTTP_BAD_GATEWAY" => Ok(502u16),
103-
"HTTP_SERVICE_UNAVAILABLE" => Ok(503u16),
104-
"HTTP_GATEWAY_TIMEOUT" => Ok(504u16),
105-
"HTTP_VERSION_NOT_SUPPORTED" => Ok(505u16),
106-
"HTTP_INSUFFICIENT_STORAGE" => Ok(507u16),
210+
"HTTP_INTERNAL_SERVER_ERROR" => lua.pack(500u16),
211+
"HTTP_METHOD_NOT_IMPLEMENTED" => lua.pack(501u16),
212+
"HTTP_BAD_GATEWAY" => lua.pack(502u16),
213+
"HTTP_SERVICE_UNAVAILABLE" => lua.pack(503u16),
214+
"HTTP_GATEWAY_TIMEOUT" => lua.pack(504u16),
215+
"HTTP_VERSION_NOT_SUPPORTED" => lua.pack(505u16),
216+
"HTTP_INSUFFICIENT_STORAGE" => lua.pack(507u16),
107217
_ => Err(mlua::Error::external(anyhow!(
108218
"attempt to index unknown field: {}",
109219
key
@@ -194,7 +304,7 @@ pub async fn lua(
194304
},
195305
res: CandyResponse {
196306
status: 200,
197-
headers: HeaderMap::new(),
307+
headers: CandyHeaders::new(HeaderMap::new()),
198308
body: "".to_string(),
199309
},
200310
},
@@ -217,6 +327,15 @@ pub async fn lua(
217327
let mut response = Response::builder();
218328
let body = Body::from(res.body);
219329
response = response.status(res.status);
330+
331+
// 添加响应头
332+
let headers = response.headers_mut().unwrap();
333+
if let Ok(guard) = res.headers.headers.lock() {
334+
for (name, value) in guard.iter() {
335+
headers.append(name.clone(), value.clone());
336+
}
337+
}
338+
220339
let response = response
221340
.body(body)
222341
.with_context(|| "Failed to build HTTP response with lua")?;

0 commit comments

Comments
 (0)