Skip to content
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
344 changes: 342 additions & 2 deletions internal/api/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ func SetupAuthRoutes(rg *gin.RouterGroup, authService *auth.Service) {
rg.GET("/oauth2/config", authHandler.HandleOAuth2Config)
rg.POST("/oauth2/config", authHandler.HandleOAuth2Config)
rg.DELETE("/oauth2/config", authHandler.HandleOAuth2Config)
// OIDC Discovery
rg.GET("/oauth2/discover", authHandler.HandleOIDCDiscover)
}

// createProxyClient 创建支持系统代理的HTTP客户端
Expand Down Expand Up @@ -400,6 +402,8 @@ func (h *AuthHandler) HandleOAuth2Callback(c *gin.Context) {
h.handleGitHubOAuth(c, code)
case "cloudflare":
h.handleCloudflareOAuth(c, code)
case "custom":
h.handleCustomOIDC(c, code)
default:
c.JSON(http.StatusOK, gin.H{
"success": false,
Expand Down Expand Up @@ -896,7 +900,8 @@ func (h *AuthHandler) HandleOAuth2Login(c *gin.Context) {
q.Set("scope", scopes)
}

if provider == "cloudflare" {
// Cloudflare 和 Custom OIDC 需要设置 response_type=code(OIDC 标准)
if provider == "cloudflare" || provider == "custom" {
q.Set("response_type", "code")
}

Expand All @@ -906,14 +911,349 @@ func (h *AuthHandler) HandleOAuth2Login(c *gin.Context) {
c.Redirect(http.StatusFound, loginURL)
}

// handleCustomOIDC 处理 Custom OIDC 回调
func (h *AuthHandler) handleCustomOIDC(c *gin.Context, code string) {
// 读取配置
cfgStr, err := h.authService.GetSystemConfig("oauth2_config")
if err != nil || cfgStr == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "Custom OIDC 未配置"})
return
}

type customCfg struct {
ClientID string `json:"clientId"`
ClientSecret string `json:"clientSecret"`
AuthURL string `json:"authUrl"`
TokenURL string `json:"tokenUrl"`
UserInfoURL string `json:"userInfoUrl"`
RedirectURI string `json:"redirectUri"`
Scopes []string `json:"scopes"`
UserIDPath string `json:"userIdPath"`
UsernamePath string `json:"usernamePath"`
DisplayName string `json:"displayName"`
}
var cfg customCfg
_ = json.Unmarshal([]byte(cfgStr), &cfg)
Copy link

Copilot AI Dec 18, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The error from json.Unmarshal is silently ignored. If the configuration JSON is malformed, this could lead to using zero values in the cfg struct, resulting in unclear error messages later. Consider checking the unmarshal error and returning a specific error message indicating the configuration is corrupted or invalid.

Suggested change
_ = json.Unmarshal([]byte(cfgStr), &cfg)
if err := json.Unmarshal([]byte(cfgStr), &cfg); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "Custom OIDC 配置无效"})
return
}

Copilot uses AI. Check for mistakes.

if cfg.ClientID == "" || cfg.ClientSecret == "" || cfg.TokenURL == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "Custom OIDC 配置不完整"})
return
}

// 设置默认值
if cfg.UserIDPath == "" {
cfg.UserIDPath = "sub"
}
if cfg.UsernamePath == "" {
cfg.UsernamePath = "preferred_username"
}
if cfg.DisplayName == "" {
cfg.DisplayName = "OIDC"
}

// 交换 access token
form := url.Values{}
form.Set("client_id", cfg.ClientID)
form.Set("client_secret", cfg.ClientSecret)
form.Set("code", code)
form.Set("grant_type", "authorization_code")

// 设置 redirect_uri
redirectURI := cfg.RedirectURI
if redirectURI == "" {
baseURL := fmt.Sprintf("%s://%s", "http", c.Request.Host)
Copy link

Copilot AI Dec 18, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The redirect URI construction defaults to HTTP when TLS information is not available. This is a security concern because the redirect URI should match the scheme used to access the application. If a user accesses via HTTPS but the redirect URI is constructed as HTTP, it could lead to security issues or OAuth flow failures. Consider defaulting to HTTPS or requiring explicit configuration of the base URL.

Suggested change
baseURL := fmt.Sprintf("%s://%s", "http", c.Request.Host)
// 根据请求信息推断协议,默认使用 HTTPS 以避免降级
scheme := "https"
if proto := c.Request.Header.Get("X-Forwarded-Proto"); proto != "" {
scheme = proto
} else if c.Request.TLS == nil {
// 在本地开发等纯 HTTP 场景下,兼容使用 HTTP
if strings.HasPrefix(c.Request.Host, "localhost") || strings.HasPrefix(c.Request.Host, "127.0.0.1") {
scheme = "http"
}
}
baseURL := fmt.Sprintf("%s://%s", scheme, c.Request.Host)

Copilot uses AI. Check for mistakes.
redirectURI = baseURL + "/api/oauth2/callback"
}
form.Set("redirect_uri", redirectURI)

fmt.Printf("🔍 Custom OIDC Token 请求: token_url=%s, redirect_uri=%s\n", cfg.TokenURL, redirectURI)

tokenReq, _ := http.NewRequest("POST", cfg.TokenURL, strings.NewReader(form.Encode()))
Comment thread
moooyo marked this conversation as resolved.
Outdated
tokenReq.Header.Set("Accept", "application/json")
tokenReq.Header.Set("Content-Type", "application/x-www-form-urlencoded")

// 使用支持代理的HTTP客户端
proxyClient := h.createProxyClient()
resp, err := proxyClient.Do(tokenReq)
if err != nil {
fmt.Printf("❌ Custom OIDC Token 请求错误: %v\n", err)
c.JSON(http.StatusBadGateway, gin.H{"error": "请求 OIDC Token 失败"})
return
}
defer resp.Body.Close()

if resp.StatusCode >= 400 {
bodyBytes, _ := ioutil.ReadAll(resp.Body)
fmt.Printf("❌ Custom OIDC Token 错误 %d: %s\n", resp.StatusCode, string(bodyBytes))
c.JSON(http.StatusBadGateway, gin.H{"error": "OIDC Token 接口返回错误"})
return
}

body, _ := ioutil.ReadAll(resp.Body)
fmt.Printf("🔑 Custom OIDC Token 响应: %s\n", string(body))
Copy link

Copilot AI Dec 18, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The complete token response (which may contain sensitive information like access tokens or id tokens) is logged to stdout. This could expose credentials in log files. Consider logging only metadata like status codes or redacting sensitive fields from the logged response.

Copilot uses AI. Check for mistakes.
Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@copilot open a new pull request to apply changes based on this feedback


var tokenRes struct {
AccessToken string `json:"access_token"`
IdToken string `json:"id_token"`
Scope string `json:"scope"`
TokenType string `json:"token_type"`
}
_ = json.Unmarshal(body, &tokenRes)
Comment thread
moooyo marked this conversation as resolved.
Outdated
if tokenRes.AccessToken == "" {
c.JSON(http.StatusBadGateway, gin.H{"error": "获取 AccessToken 失败"})
return
}

var userData map[string]interface{}

// 方式1: 通过 userinfo 端点获取用户信息
if cfg.UserInfoURL != "" {
userReq, _ := http.NewRequest("GET", cfg.UserInfoURL, nil)
userReq.Header.Set("Authorization", "Bearer "+tokenRes.AccessToken)
userReq.Header.Set("Accept", "application/json")

userResp, err := proxyClient.Do(userReq)
if err == nil {
defer userResp.Body.Close()
bodyBytes, _ := ioutil.ReadAll(userResp.Body)
_ = json.Unmarshal(bodyBytes, &userData)
fmt.Printf("👤 Custom OIDC 用户信息 (userinfo): %s\n", string(bodyBytes))
}
}

// 方式2: 若未获取到用户信息且 id_token 存在,则解析 id_token JWT payload
if len(userData) == 0 && tokenRes.IdToken != "" {
parts := strings.Split(tokenRes.IdToken, ".")
if len(parts) >= 2 {
payload, _ := base64.RawURLEncoding.DecodeString(parts[1])
_ = json.Unmarshal(payload, &userData)
Comment thread
moooyo marked this conversation as resolved.
Outdated
fmt.Printf("👤 Custom OIDC id_token payload: %s\n", string(payload))
Copy link

Copilot AI Dec 18, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Using fmt.Printf for logging in production code is not a best practice. These debug statements will always output to stdout regardless of log level configuration. Consider using a proper logging framework (like logrus, zap, or the standard log package) that supports log levels and structured logging, allowing these debug messages to be disabled in production.

Copilot uses AI. Check for mistakes.
}
Comment thread
moooyo marked this conversation as resolved.
Outdated
}

if len(userData) == 0 {
c.JSON(http.StatusBadGateway, gin.H{"error": "无法获取 OIDC 用户信息"})
return
}

// 提取用户 ID(使用配置的 userIdPath)
providerID := h.extractFieldFromUserData(userData, cfg.UserIDPath)
if providerID == "" {
// 回退到常用字段
providerID = h.extractFieldFromUserData(userData, "sub")
if providerID == "" {
providerID = h.extractFieldFromUserData(userData, "id")
}
}

if providerID == "" {
c.JSON(http.StatusBadGateway, gin.H{"error": "无法获取 OIDC 用户唯一标识"})
return
}

// 提取用户名(使用配置的 usernamePath)
login := h.extractFieldFromUserData(userData, cfg.UsernamePath)
if login == "" {
// 回退到常用字段
login = h.extractFieldFromUserData(userData, "preferred_username")
if login == "" {
login = h.extractFieldFromUserData(userData, "email")
}
if login == "" {
login = h.extractFieldFromUserData(userData, "name")
}
if login == "" {
login = providerID // 最后回退到使用 providerID
}
}

username := "custom:" + login
Comment thread
moooyo marked this conversation as resolved.

// 保存用户信息
dataJSON, _ := json.Marshal(userData)
if err := h.authService.SaveOAuthUser("custom", providerID, username, string(dataJSON)); err != nil {
fmt.Printf("❌ 保存 Custom OIDC 用户失败: %v\n", err)
// 重定向到错误页面
baseURL := ""
if cfg.RedirectURI != "" {
baseURL = strings.Replace(cfg.RedirectURI, "/api/oauth2/callback", "", 1)
} else {
scheme := "http"
if c.Request.TLS != nil || c.GetHeader("X-Forwarded-Proto") == "https" {
scheme = "https"
}
Comment on lines +1204 to +1206
Copy link

Copilot AI Dec 18, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The scheme detection logic defaults to "http" when TLS is not detected. This could be problematic in deployments behind reverse proxies that don't set the X-Forwarded-Proto header. Consider checking additional headers like X-Forwarded-Ssl or Forwarded, or falling back to a safer default based on the deployment environment configuration.

Suggested change
if c.Request.TLS != nil || c.GetHeader("X-Forwarded-Proto") == "https" {
scheme = "https"
}
xForwardedProto := strings.ToLower(c.GetHeader("X-Forwarded-Proto"))
xForwardedSsl := strings.ToLower(c.GetHeader("X-Forwarded-Ssl"))
forwarded := c.GetHeader("Forwarded")
if c.Request.TLS != nil || xForwardedProto == "https" || xForwardedSsl == "on" {
scheme = "https"
} else if forwarded != "" {
// Parse the standardized Forwarded header, e.g. "proto=https;host=example.com"
for _, part := range strings.Split(forwarded, ";") {
for _, item := range strings.Split(part, ",") {
item = strings.TrimSpace(strings.ToLower(item))
if strings.HasPrefix(item, "proto=") {
protoVal := strings.TrimPrefix(item, "proto=")
if protoVal == "https" {
scheme = "https"
}
break
}
}
if scheme == "https" {
break
}
}
}

Copilot uses AI. Check for mistakes.
baseURL = fmt.Sprintf("%s://%s", scheme, c.Request.Host)
}
errorURL := fmt.Sprintf("%s/oauth-error?error=%s&provider=custom",
baseURL, url.QueryEscape(err.Error()))
c.Redirect(http.StatusFound, errorURL)
return
}

// 创建会话 (24小时有效期)
sessionID, err := h.authService.CreateSession(username, 24*time.Hour)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "创建会话失败"})
return
}

// 设置 cookie
c.SetCookie("session", sessionID, 24*60*60, "/", "", false, true)

// 重定向到 dashboard
redirectURL := c.Query("redirect")
if redirectURL == "" {
redirectURL = strings.Replace(cfg.RedirectURI, "/api/oauth2/callback", "/dashboard", 1)
}

accept := c.GetHeader("Accept")
if strings.Contains(accept, "text/html") || strings.Contains(accept, "application/xhtml+xml") || redirectURL != "" {
c.Redirect(http.StatusFound, redirectURL)
return
}

c.JSON(http.StatusOK, gin.H{
"success": true,
"provider": "custom",
"username": username,
"message": "登录成功",
})
}
Comment thread
moooyo marked this conversation as resolved.

// extractFieldFromUserData 从用户数据中提取字段(支持简单的点号路径)
func (h *AuthHandler) extractFieldFromUserData(data map[string]interface{}, path string) string {
if path == "" {
return ""
}

parts := strings.Split(path, ".")
current := data

for i, part := range parts {
if val, ok := current[part]; ok {
if i == len(parts)-1 {
// 最后一个部分,转换为字符串
return fmt.Sprintf("%v", val)
Comment on lines +1257 to +1258
Copy link

Copilot AI Dec 18, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The extractFieldFromUserData function uses fmt.Sprintf to convert the final value to a string, which will convert any type to its string representation. However, this might not be appropriate for all data types. For example, if the field contains a boolean true, it will become the string "true". Consider validating that the extracted value is actually a string type, or document this behavior clearly.

Suggested change
// 最后一个部分,转换为字符串
return fmt.Sprintf("%v", val)
// 最后一个部分,仅当值为字符串时返回
if s, ok := val.(string); ok {
return s
}
return ""

Copilot uses AI. Check for mistakes.
}
// 不是最后一个部分,继续深入
if nested, ok := val.(map[string]interface{}); ok {
current = nested
} else {
return ""
}
} else {
return ""
}
}
return ""
}

// HandleOAuth2Provider 仅返回当前绑定的 OAuth2 provider(用于登录页)
func (h *AuthHandler) HandleOAuth2Provider(c *gin.Context) {
provider, _ := h.authService.GetSystemConfig("oauth2_provider")
disableLogin, _ := h.authService.GetSystemConfig("disable_login")

c.JSON(http.StatusOK, gin.H{
resp := gin.H{
"success": true,
"provider": provider,
"disableLogin": disableLogin == "true",
}

// 如果是 custom provider,返回 displayName
if provider == "custom" {
cfgStr, _ := h.authService.GetSystemConfig("oauth2_config")
if cfgStr != "" {
var cfg map[string]interface{}
_ = json.Unmarshal([]byte(cfgStr), &cfg)
if displayName, ok := cfg["displayName"].(string); ok && displayName != "" {
Comment thread
moooyo marked this conversation as resolved.
Outdated
resp["displayName"] = displayName
}
}
}

c.JSON(http.StatusOK, resp)
}

// HandleOIDCDiscover 处理 OIDC Discovery 请求
// GET /api/oauth2/discover?url=https://auth.example.com/.well-known/openid-configuration
func (h *AuthHandler) HandleOIDCDiscover(c *gin.Context) {
discoveryURL := c.Query("url")
if discoveryURL == "" {
c.JSON(http.StatusBadRequest, gin.H{
"success": false,
"error": "缺少 url 参数",
})
return
}

Comment thread
moooyo marked this conversation as resolved.
Outdated
// 使用支持代理的 HTTP 客户端
proxyClient := h.createProxyClient()

req, err := http.NewRequest("GET", discoveryURL, nil)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{
"success": false,
"error": "无效的 Discovery URL",
})
return
}
req.Header.Set("Accept", "application/json")

resp, err := proxyClient.Do(req)
if err != nil {
c.JSON(http.StatusBadGateway, gin.H{
"success": false,
"error": fmt.Sprintf("无法连接到 OIDC 服务器: %v", err),
})
return
}
defer resp.Body.Close()

if resp.StatusCode != http.StatusOK {
c.JSON(http.StatusBadGateway, gin.H{
"success": false,
"error": fmt.Sprintf("OIDC Discovery 失败,状态码: %d", resp.StatusCode),
})
return
}

body, err := ioutil.ReadAll(resp.Body)
Copy link

Copilot AI Dec 18, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The deprecated ioutil.ReadAll is used instead of the recommended io.ReadAll. Since Go 1.16, io.ReadAll should be used as ioutil functions have been deprecated and moved to the io package.

Copilot uses AI. Check for mistakes.
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{
"success": false,
"error": "读取响应失败",
})
return
}

var discoveryData map[string]interface{}
if err := json.Unmarshal(body, &discoveryData); err != nil {
c.JSON(http.StatusBadGateway, gin.H{
"success": false,
"error": "解析 OIDC 配置失败",
})
return
}

// 提取关键端点
authorizationEndpoint, _ := discoveryData["authorization_endpoint"].(string)
tokenEndpoint, _ := discoveryData["token_endpoint"].(string)
userinfoEndpoint, _ := discoveryData["userinfo_endpoint"].(string)
issuer, _ := discoveryData["issuer"].(string)

if authorizationEndpoint == "" || tokenEndpoint == "" {
c.JSON(http.StatusBadGateway, gin.H{
"success": false,
"error": "OIDC 配置不完整,缺少必要端点",
})
return
}

c.JSON(http.StatusOK, gin.H{
"success": true,
"issuer": issuer,
"authorizationEndpoint": authorizationEndpoint,
"tokenEndpoint": tokenEndpoint,
"userinfoEndpoint": userinfoEndpoint,
})
}
Loading