diff --git a/internal/controller/oidc_controller_test.go b/internal/controller/oidc_controller_test.go index 78796c49..b22ddc54 100644 --- a/internal/controller/oidc_controller_test.go +++ b/internal/controller/oidc_controller_test.go @@ -209,6 +209,26 @@ func TestOIDCController(t *testing.T) { }, // --- authorize-complete --- + { + description: "Should fail if oidc is disabled", + oidcDisabled: true, + run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { + body, err := json.Marshal(AuthorizeCompleteRequest{Ticket: "some-ticket"}) + require.NoError(t, err) + + req := httptest.NewRequest("POST", "/api/oidc/authorize-complete", strings.NewReader(string(body))) + req.Header.Set("Content-Type", "application/json") + router.ServeHTTP(recorder, req) + + assert.Equal(t, http.StatusOK, recorder.Code) + + var res map[string]any + require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &res)) + redirectURI, ok := res["redirect_uri"].(string) + require.True(t, ok) + assert.Contains(t, redirectURI, oidcService.GetIssuer()+"/error") + }, + }, { description: "Authorize complete returns a JSON error when the user context is missing", run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { diff --git a/internal/controller/proxy_controller.go b/internal/controller/proxy_controller.go index 891ff59b..ffafaffd 100644 --- a/internal/controller/proxy_controller.go +++ b/internal/controller/proxy_controller.go @@ -158,7 +158,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) { return } - c.Redirect(http.StatusTemporaryRedirect, redirectURL) + c.Redirect(http.StatusFound, redirectURL) return } @@ -207,7 +207,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) { return } - c.Redirect(http.StatusTemporaryRedirect, redirectURL) + c.Redirect(http.StatusFound, redirectURL) return } @@ -251,7 +251,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) { return } - c.Redirect(http.StatusTemporaryRedirect, redirectURL) + c.Redirect(http.StatusFound, redirectURL) return } } @@ -300,7 +300,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) { return } - c.Redirect(http.StatusTemporaryRedirect, redirectURL) + c.Redirect(http.StatusFound, redirectURL) } func (controller *ProxyController) setHeaders(c *gin.Context, acls *model.App) { @@ -336,7 +336,7 @@ func (controller *ProxyController) handleError(c *gin.Context, proxyCtx ProxyCon return } - c.Redirect(http.StatusTemporaryRedirect, redirectURL) + c.Redirect(http.StatusFound, redirectURL) } func (controller *ProxyController) getHeader(c *gin.Context, header string) (string, bool) { diff --git a/internal/controller/proxy_controller_test.go b/internal/controller/proxy_controller_test.go index 1b2eb5c6..faa9934b 100644 --- a/internal/controller/proxy_controller_test.go +++ b/internal/controller/proxy_controller_test.go @@ -2,6 +2,9 @@ package controller import ( "context" + "encoding/base64" + "fmt" + "net/http" "net/http/httptest" "net/url" "testing" @@ -63,6 +66,17 @@ func TestProxyController(t *testing.T) { } tests := []testCase{ + { + description: "Should get bad request on invalid proxy", + middlewares: []gin.HandlerFunc{}, + run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { + req := httptest.NewRequest("GET", "/api/auth/invalid", nil) + router.ServeHTTP(recorder, req) + + assert.Equal(t, http.StatusBadRequest, recorder.Code) + assert.Contains(t, recorder.Body.String(), "Bad request") + }, + }, { description: "Default forward auth should be detected and used for traefik", middlewares: []gin.HandlerFunc{}, @@ -74,7 +88,7 @@ func TestProxyController(t *testing.T) { req.Header.Set("user-agent", browserUserAgent) router.ServeHTTP(recorder, req) - assert.Equal(t, 307, recorder.Code) + assert.Equal(t, http.StatusFound, recorder.Code) location := recorder.Header().Get("Location") assert.Contains(t, location, url.QueryEscape("https://test.example.com/")) assert.Contains(t, location, "login_for=app") @@ -89,7 +103,7 @@ func TestProxyController(t *testing.T) { req.Header.Set("x-original-url", "https://test.example.com/") req.Header.Set("user-agent", browserUserAgent) router.ServeHTTP(recorder, req) - assert.Equal(t, 401, recorder.Code) + assert.Equal(t, http.StatusUnauthorized, recorder.Code) location := recorder.Header().Get("x-tinyauth-location") assert.Contains(t, location, url.QueryEscape("https://test.example.com/")) assert.Contains(t, location, "login_for=app") @@ -105,7 +119,7 @@ func TestProxyController(t *testing.T) { req.Header.Set("x-forwarded-proto", "https") req.Header.Set("user-agent", browserUserAgent) router.ServeHTTP(recorder, req) - assert.Equal(t, 307, recorder.Code) + assert.Equal(t, http.StatusFound, recorder.Code) location := recorder.Header().Get("Location") assert.Contains(t, location, url.QueryEscape("https://test.example.com/hello")) assert.Contains(t, location, "login_for=app") @@ -123,7 +137,7 @@ func TestProxyController(t *testing.T) { req.Header.Set("user-agent", browserUserAgent) router.ServeHTTP(recorder, req) - assert.Equal(t, 307, recorder.Code) + assert.Equal(t, http.StatusFound, recorder.Code) location := recorder.Header().Get("Location") assert.Contains(t, location, url.QueryEscape("https://test.example.com/")) assert.Contains(t, location, "login_for=app") @@ -140,7 +154,7 @@ func TestProxyController(t *testing.T) { req.Header.Set("x-forwarded-uri", "/") req.Header.Set("user-agent", browserUserAgent) router.ServeHTTP(recorder, req) - assert.Equal(t, 401, recorder.Code) + assert.Equal(t, http.StatusUnauthorized, recorder.Code) location := recorder.Header().Get("x-tinyauth-location") assert.Contains(t, location, url.QueryEscape("https://test.example.com/")) assert.Contains(t, location, "login_for=app") @@ -158,7 +172,7 @@ func TestProxyController(t *testing.T) { req.Header.Set("x-forwarded-uri", "/hello") req.Header.Set("user-agent", browserUserAgent) router.ServeHTTP(recorder, req) - assert.Equal(t, 307, recorder.Code) + assert.Equal(t, http.StatusFound, recorder.Code) location := recorder.Header().Get("Location") assert.Contains(t, location, url.QueryEscape("https://test.example.com/")) assert.Contains(t, location, "login_for=app") @@ -175,7 +189,7 @@ func TestProxyController(t *testing.T) { req.Header.Set("x-forwarded-uri", "/") router.ServeHTTP(recorder, req) - assert.Equal(t, 401, recorder.Code) + assert.Equal(t, http.StatusUnauthorized, recorder.Code) assert.Contains(t, recorder.Body.String(), `"status":401`) assert.Contains(t, recorder.Body.String(), `"message":"Unauthorized"`) }, @@ -190,7 +204,7 @@ func TestProxyController(t *testing.T) { req.Header.Set("x-forwarded-uri", "/") router.ServeHTTP(recorder, req) - assert.Equal(t, 401, recorder.Code) + assert.Equal(t, http.StatusUnauthorized, recorder.Code) assert.Contains(t, recorder.Body.String(), `"status":401`) assert.Contains(t, recorder.Body.String(), `"message":"Unauthorized"`) }, @@ -205,7 +219,7 @@ func TestProxyController(t *testing.T) { req.Header.Set("x-forwarded-uri", "/hello") router.ServeHTTP(recorder, req) - assert.Equal(t, 401, recorder.Code) + assert.Equal(t, http.StatusUnauthorized, recorder.Code) assert.Contains(t, recorder.Body.String(), `"status":401`) assert.Contains(t, recorder.Body.String(), `"message":"Unauthorized"`) }, @@ -222,7 +236,7 @@ func TestProxyController(t *testing.T) { req.Header.Set("x-forwarded-uri", "/") router.ServeHTTP(recorder, req) - assert.Equal(t, 200, recorder.Code) + assert.Equal(t, http.StatusOK, recorder.Code) assert.Equal(t, "testuser", recorder.Header().Get("remote-user")) assert.Equal(t, "Testuser", recorder.Header().Get("remote-name")) assert.Equal(t, "testuser@example.com", recorder.Header().Get("remote-email")) @@ -238,7 +252,7 @@ func TestProxyController(t *testing.T) { req.Header.Set("x-original-url", "https://test.example.com/") router.ServeHTTP(recorder, req) - assert.Equal(t, 200, recorder.Code) + assert.Equal(t, http.StatusOK, recorder.Code) assert.Equal(t, "testuser", recorder.Header().Get("remote-user")) assert.Equal(t, "Testuser", recorder.Header().Get("remote-name")) assert.Equal(t, "testuser@example.com", recorder.Header().Get("remote-email")) @@ -255,7 +269,7 @@ func TestProxyController(t *testing.T) { req.Header.Set("x-forwarded-proto", "https") router.ServeHTTP(recorder, req) - assert.Equal(t, 200, recorder.Code) + assert.Equal(t, http.StatusOK, recorder.Code) assert.Equal(t, "testuser", recorder.Header().Get("remote-user")) assert.Equal(t, "Testuser", recorder.Header().Get("remote-name")) assert.Equal(t, "testuser@example.com", recorder.Header().Get("remote-email")) @@ -270,7 +284,7 @@ func TestProxyController(t *testing.T) { req.Header.Set("x-forwarded-proto", "https") req.Header.Set("x-forwarded-uri", "/allowed") router.ServeHTTP(recorder, req) - assert.Equal(t, 200, recorder.Code) + assert.Equal(t, http.StatusOK, recorder.Code) }, }, { @@ -280,7 +294,7 @@ func TestProxyController(t *testing.T) { req := httptest.NewRequest("GET", "/api/auth/nginx", nil) req.Header.Set("x-original-url", "https://path-allow.example.com/allowed") router.ServeHTTP(recorder, req) - assert.Equal(t, 200, recorder.Code) + assert.Equal(t, http.StatusOK, recorder.Code) }, }, { @@ -291,7 +305,7 @@ func TestProxyController(t *testing.T) { req.Host = "path-allow.example.com" req.Header.Set("x-forwarded-proto", "https") router.ServeHTTP(recorder, req) - assert.Equal(t, 200, recorder.Code) + assert.Equal(t, http.StatusOK, recorder.Code) }, }, { @@ -304,7 +318,7 @@ func TestProxyController(t *testing.T) { req.Header.Set("x-forwarded-uri", "/") req.Header.Set("x-forwarded-for", "10.10.10.10") router.ServeHTTP(recorder, req) - assert.Equal(t, 200, recorder.Code) + assert.Equal(t, http.StatusOK, recorder.Code) }, }, { @@ -315,7 +329,7 @@ func TestProxyController(t *testing.T) { req.Header.Set("x-original-url", "https://ip-bypass.example.com/") req.Header.Set("x-forwarded-for", "10.10.10.10") router.ServeHTTP(recorder, req) - assert.Equal(t, 200, recorder.Code) + assert.Equal(t, http.StatusOK, recorder.Code) }, }, { @@ -327,7 +341,7 @@ func TestProxyController(t *testing.T) { req.Header.Set("x-forwarded-proto", "https") req.Header.Set("x-forwarded-for", "10.10.10.10") router.ServeHTTP(recorder, req) - assert.Equal(t, 200, recorder.Code) + assert.Equal(t, http.StatusOK, recorder.Code) }, }, { @@ -341,7 +355,7 @@ func TestProxyController(t *testing.T) { req.Header.Set("x-forwarded-proto", "https") req.Header.Set("x-forwarded-uri", "/") router.ServeHTTP(recorder, req) - assert.Equal(t, 200, recorder.Code) + assert.Equal(t, http.StatusOK, recorder.Code) }, }, { @@ -355,12 +369,301 @@ func TestProxyController(t *testing.T) { req.Header.Set("x-forwarded-proto", "https") req.Header.Set("x-forwarded-uri", "/") router.ServeHTTP(recorder, req) - assert.Equal(t, 403, recorder.Code) + assert.Equal(t, http.StatusForbidden, recorder.Code) assert.Equal(t, "", recorder.Header().Get("remote-user")) assert.Equal(t, "", recorder.Header().Get("remote-name")) assert.Equal(t, "", recorder.Header().Get("remote-email")) }, }, + { + description: "Test IP block rule, with non browser user agent", + middlewares: []gin.HandlerFunc{}, + run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { + req := httptest.NewRequest("GET", "/api/auth/traefik", nil) + req.Header.Set("x-forwarded-host", "ip-block.example.com") + req.Header.Set("x-forwarded-proto", "https") + req.Header.Set("x-forwarded-uri", "/") + req.Header.Set("x-forwarded-for", "10.10.10.10") + router.ServeHTTP(recorder, req) + assert.Equal(t, http.StatusForbidden, recorder.Code) + assert.Contains(t, recorder.Header().Get("x-tinyauth-location"), runtime.AppURL) + assert.Contains(t, recorder.Header().Get("x-tinyauth-location"), "ip=10.10.10.10") + assert.Contains(t, recorder.Header().Get("x-tinyauth-location"), "ip-block") + + }, + }, + { + description: "Test IP block rule, with browser user agent", + middlewares: []gin.HandlerFunc{}, + run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { + req := httptest.NewRequest("GET", "/api/auth/traefik", nil) + req.Header.Set("x-forwarded-host", "ip-block.example.com") + req.Header.Set("x-forwarded-proto", "https") + req.Header.Set("x-forwarded-uri", "/") + req.Header.Set("x-forwarded-for", "10.10.10.10") + req.Header.Set("user-agent", browserUserAgent) + router.ServeHTTP(recorder, req) + assert.Equal(t, http.StatusFound, recorder.Code) + location := recorder.Header().Get("Location") + assert.Contains(t, location, url.QueryEscape("10.10.10.10")) + assert.Contains(t, location, url.QueryEscape("ip-block")) + assert.Contains(t, location, runtime.AppURL) + }, + }, + { + description: "OAuth allowed group", + middlewares: []gin.HandlerFunc{ + func(ctx *gin.Context) { + ctx.Set("context", &model.UserContext{ + Authenticated: true, + Provider: model.ProviderOAuth, + OAuth: &model.OAuthContext{ + BaseContext: model.BaseContext{ + Username: "testuser", + Name: "Testuser", + Email: "testuser@example.com", + }, + Groups: []string{"group1"}, + }, + }) + ctx.Next() + }, + }, + run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { + req := httptest.NewRequest("GET", "/api/auth/traefik", nil) + req.Header.Set("x-forwarded-host", "oauth-group.example.com") + req.Header.Set("x-forwarded-proto", "https") + req.Header.Set("x-forwarded-uri", "/") + router.ServeHTTP(recorder, req) + assert.Equal(t, http.StatusOK, recorder.Code) + assert.Equal(t, "testuser", recorder.Header().Get("remote-user")) + assert.Equal(t, "Testuser", recorder.Header().Get("remote-name")) + assert.Equal(t, "testuser@example.com", recorder.Header().Get("remote-email")) + assert.Equal(t, "group1", recorder.Header().Get("remote-groups")) + }, + }, + { + description: "OAuth not in required groups and non browser", + middlewares: []gin.HandlerFunc{ + func(ctx *gin.Context) { + ctx.Set("context", &model.UserContext{ + Authenticated: true, + Provider: model.ProviderOAuth, + OAuth: &model.OAuthContext{ + BaseContext: model.BaseContext{ + Username: "testuser", + Name: "Testuser", + Email: "testuser@example.com", + }, + Groups: []string{"group3"}, + }, + }) + ctx.Next() + }, + }, + run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { + req := httptest.NewRequest("GET", "/api/auth/traefik", nil) + req.Header.Set("x-forwarded-host", "oauth-group.example.com") + req.Header.Set("x-forwarded-proto", "https") + req.Header.Set("x-forwarded-uri", "/") + router.ServeHTTP(recorder, req) + assert.Equal(t, http.StatusForbidden, recorder.Code) + assert.Equal(t, "", recorder.Header().Get("remote-user")) + assert.Equal(t, "", recorder.Header().Get("remote-name")) + assert.Equal(t, "", recorder.Header().Get("remote-email")) + assert.Equal(t, "", recorder.Header().Get("remote-groups")) + assert.Contains(t, recorder.Header().Get("x-tinyauth-location"), runtime.AppURL) + assert.Contains(t, recorder.Header().Get("x-tinyauth-location"), "groupErr=true") + assert.Contains(t, recorder.Header().Get("x-tinyauth-location"), "oauth-group") + }, + }, + { + description: "OAuth not in required groups and browser", + middlewares: []gin.HandlerFunc{ + func(ctx *gin.Context) { + ctx.Set("context", &model.UserContext{ + Authenticated: true, + Provider: model.ProviderOAuth, + OAuth: &model.OAuthContext{ + BaseContext: model.BaseContext{ + Username: "testuser", + Name: "Testuser", + Email: "testuser@example.com", + }, + Groups: []string{"group3"}, + }, + }) + ctx.Next() + }, + }, + run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { + req := httptest.NewRequest("GET", "/api/auth/traefik", nil) + req.Header.Set("x-forwarded-host", "oauth-group.example.com") + req.Header.Set("x-forwarded-proto", "https") + req.Header.Set("x-forwarded-uri", "/") + req.Header.Set("user-agent", browserUserAgent) + router.ServeHTTP(recorder, req) + assert.Equal(t, http.StatusFound, recorder.Code) + location := recorder.Header().Get("Location") + assert.Contains(t, location, "groupErr=true") + assert.Contains(t, location, "oauth-group") + assert.Contains(t, location, runtime.AppURL) + }, + }, + { + description: "LDAP allowed group", + middlewares: []gin.HandlerFunc{ + func(ctx *gin.Context) { + ctx.Set("context", &model.UserContext{ + Authenticated: true, + Provider: model.ProviderLDAP, + LDAP: &model.LDAPContext{ + BaseContext: model.BaseContext{ + Username: "testuser", + Name: "Testuser", + Email: "testuser@example.com", + }, + Groups: []string{"group1"}, + }, + }) + ctx.Next() + }, + }, + run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { + req := httptest.NewRequest("GET", "/api/auth/traefik", nil) + req.Header.Set("x-forwarded-host", "ldap-group.example.com") + req.Header.Set("x-forwarded-proto", "https") + req.Header.Set("x-forwarded-uri", "/") + router.ServeHTTP(recorder, req) + assert.Equal(t, http.StatusOK, recorder.Code) + assert.Equal(t, "testuser", recorder.Header().Get("remote-user")) + assert.Equal(t, "Testuser", recorder.Header().Get("remote-name")) + assert.Equal(t, "testuser@example.com", recorder.Header().Get("remote-email")) + assert.Equal(t, "group1", recorder.Header().Get("remote-groups")) + }, + }, + { + description: "LDAP not in required groups and non browser", + middlewares: []gin.HandlerFunc{ + func(ctx *gin.Context) { + ctx.Set("context", &model.UserContext{ + Authenticated: true, + Provider: model.ProviderLDAP, + LDAP: &model.LDAPContext{ + BaseContext: model.BaseContext{ + Username: "testuser", + Name: "Testuser", + Email: "testuser@example.com", + }, + Groups: []string{"group3"}, + }, + }) + ctx.Next() + }, + }, + run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { + req := httptest.NewRequest("GET", "/api/auth/traefik", nil) + req.Header.Set("x-forwarded-host", "ldap-group.example.com") + req.Header.Set("x-forwarded-proto", "https") + req.Header.Set("x-forwarded-uri", "/") + router.ServeHTTP(recorder, req) + assert.Equal(t, http.StatusForbidden, recorder.Code) + assert.Equal(t, "", recorder.Header().Get("remote-user")) + assert.Equal(t, "", recorder.Header().Get("remote-name")) + assert.Equal(t, "", recorder.Header().Get("remote-email")) + assert.Equal(t, "", recorder.Header().Get("remote-groups")) + assert.Contains(t, recorder.Header().Get("x-tinyauth-location"), runtime.AppURL) + assert.Contains(t, recorder.Header().Get("x-tinyauth-location"), "groupErr=true") + assert.Contains(t, recorder.Header().Get("x-tinyauth-location"), "ldap-group") + }, + }, + { + description: "LDAP not in required groups and browser", + middlewares: []gin.HandlerFunc{ + func(ctx *gin.Context) { + ctx.Set("context", &model.UserContext{ + Authenticated: true, + Provider: model.ProviderLDAP, + LDAP: &model.LDAPContext{ + BaseContext: model.BaseContext{ + Username: "testuser", + Name: "Testuser", + Email: "testuser@example.com", + }, + Groups: []string{"group3"}, + }, + }) + ctx.Next() + }, + }, + run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { + req := httptest.NewRequest("GET", "/api/auth/traefik", nil) + req.Header.Set("x-forwarded-host", "ldap-group.example.com") + req.Header.Set("x-forwarded-proto", "https") + req.Header.Set("x-forwarded-uri", "/") + req.Header.Set("user-agent", browserUserAgent) + router.ServeHTTP(recorder, req) + assert.Equal(t, http.StatusFound, recorder.Code) + location := recorder.Header().Get("Location") + assert.Contains(t, location, "groupErr=true") + assert.Contains(t, location, "ldap-group") + assert.Contains(t, location, runtime.AppURL) + }, + }, + { + description: "Should add basic auth if it's in ACLs", + middlewares: []gin.HandlerFunc{ + simpleCtx, + }, + run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { + req := httptest.NewRequest("GET", "/api/auth/traefik", nil) + req.Header.Set("x-forwarded-host", "basic-auth.example.com") + req.Header.Set("x-forwarded-proto", "https") + req.Header.Set("x-forwarded-uri", "/") + req.Header.Set("authorization", "foo") // should be overridden by basic auth + router.ServeHTTP(recorder, req) + + assert.Equal(t, http.StatusOK, recorder.Code) + authorizationHeader := recorder.Header().Get("Authorization") + assert.NotEmpty(t, authorizationHeader) + assert.Equal(t, fmt.Sprintf("Basic %s", base64.StdEncoding.EncodeToString([]byte("test:password"))), authorizationHeader) + }, + }, + { + description: "Authorization header should be preserved when not basic auth acls", + middlewares: []gin.HandlerFunc{ + simpleCtx, + }, + run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { + req := httptest.NewRequest("GET", "/api/auth/traefik", nil) + req.Header.Set("x-forwarded-host", "test.example.com") + req.Header.Set("x-forwarded-proto", "https") + req.Header.Set("x-forwarded-uri", "/") + req.Header.Set("authorization", "Bearer mytoken") + router.ServeHTTP(recorder, req) + + assert.Equal(t, http.StatusOK, recorder.Code) + authorizationHeader := recorder.Header().Get("Authorization") + assert.NotEmpty(t, authorizationHeader) + assert.Equal(t, "Bearer mytoken", authorizationHeader) + }, + }, + { + description: "Should add response headers if present", + middlewares: []gin.HandlerFunc{ + simpleCtx, + }, + run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { + req := httptest.NewRequest("GET", "/api/auth/traefik", nil) + req.Header.Set("x-forwarded-host", "response-headers.example.com") + req.Header.Set("x-forwarded-proto", "https") + req.Header.Set("x-forwarded-uri", "/") + router.ServeHTTP(recorder, req) + + assert.Equal(t, http.StatusOK, recorder.Code) + assert.Equal(t, "bar", recorder.Header().Get("x-foo")) + }, + }, } store := memory.New() diff --git a/internal/controller/resources_controller_test.go b/internal/controller/resources_controller_test.go index fe8cf48b..540a899a 100644 --- a/internal/controller/resources_controller_test.go +++ b/internal/controller/resources_controller_test.go @@ -9,6 +9,7 @@ import ( "github.com/gin-gonic/gin" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/test" ) @@ -18,8 +19,12 @@ func TestResourcesController(t *testing.T) { err := os.MkdirAll(cfg.Resources.Path, 0777) require.NoError(t, err) + // create a "backup" of the original configuration to restore after each test + originalCfg := cfg.Resources + type testCase struct { description string + customCfg *model.ResourcesConfig run func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) } @@ -52,6 +57,32 @@ func TestResourcesController(t *testing.T) { assert.Equal(t, 404, recorder.Code) }, }, + { + description: "Ensure resources controller returns 404 when resources path is empty", + customCfg: &model.ResourcesConfig{ + Path: "", + Enabled: true, + }, + run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { + req := httptest.NewRequest("GET", "/resources/testfile.txt", nil) + router.ServeHTTP(recorder, req) + + assert.Equal(t, 404, recorder.Code) + }, + }, + { + description: "Ensure resources controller returns 403 when resources are disabled", + customCfg: &model.ResourcesConfig{ + Path: cfg.Resources.Path, + Enabled: false, + }, + run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { + req := httptest.NewRequest("GET", "/resources/testfile.txt", nil) + router.ServeHTTP(recorder, req) + + assert.Equal(t, 403, recorder.Code) + }, + }, } testFilePath := cfg.Resources.Path + "/testfile.txt" @@ -68,6 +99,14 @@ func TestResourcesController(t *testing.T) { group := router.Group("/") gin.SetMode(gin.TestMode) + // if custom configuration is provided, override the default config + if test.customCfg != nil { + cfg.Resources = *test.customCfg + } else { + // Reset to default configuration for each test + cfg.Resources = originalCfg + } + NewResourcesController(ResourcesControllerInput{ RouterGroup: group, Config: &cfg, diff --git a/internal/controller/user_controller_test.go b/internal/controller/user_controller_test.go index 0ee63dfc..4f081b9b 100644 --- a/internal/controller/user_controller_test.go +++ b/internal/controller/user_controller_test.go @@ -41,6 +41,7 @@ func TestUserController(t *testing.T) { TOTPPending: true, }, }) + c.Next() } totpAttrCtx := func(c *gin.Context) { @@ -56,6 +57,7 @@ func TestUserController(t *testing.T) { TOTPPending: true, }, }) + c.Next() } simpleCtx := func(c *gin.Context) { @@ -70,6 +72,7 @@ func TestUserController(t *testing.T) { }, }, }) + c.Next() } store := memory.New() @@ -81,6 +84,40 @@ func TestUserController(t *testing.T) { } tests := []testCase{ + { + description: "Login should fail gracefully on invalid json", + middlewares: []gin.HandlerFunc{}, + run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { + req := httptest.NewRequest("POST", "/api/user/login", strings.NewReader(`{"username": "testuser", "password":`)) + req.Header.Set("Content-Type", "application/json") + + router.ServeHTTP(recorder, req) + + assert.Equal(t, 400, recorder.Code) + assert.Contains(t, recorder.Body.String(), "Bad Request") + }, + }, + { + description: "Should fail on missing user", + middlewares: []gin.HandlerFunc{}, + run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { + loginReq := LoginRequest{ + Username: "nonexistentuser", + Password: "password", + } + loginReqBody, err := json.Marshal(loginReq) + require.NoError(t, err) + + req := httptest.NewRequest("POST", "/api/user/login", strings.NewReader(string(loginReqBody))) + req.Header.Set("Content-Type", "application/json") + + router.ServeHTTP(recorder, req) + + assert.Equal(t, 401, recorder.Code) + assert.Len(t, recorder.Result().Cookies(), 0) + assert.Contains(t, recorder.Body.String(), "Unauthorized") + }, + }, { description: "Should be able to login with valid credentials", middlewares: []gin.HandlerFunc{}, @@ -242,6 +279,87 @@ func TestUserController(t *testing.T) { assert.Equal(t, -1, cookie.MaxAge) // MaxAge -1 means delete cookie }, }, + { + description: "Logout should be treated as valid without a session cookie", + middlewares: []gin.HandlerFunc{}, + run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { + req := httptest.NewRequest("POST", "/api/user/logout", nil) + router.ServeHTTP(recorder, req) + + assert.Equal(t, 200, recorder.Code) + }, + }, + { + description: "TOTP should gracefully reject invalid json", + middlewares: []gin.HandlerFunc{}, + run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { + req := httptest.NewRequest("POST", "/api/user/totp", strings.NewReader(`{"code":`)) + req.Header.Set("Content-Type", "application/json") + + router.ServeHTTP(recorder, req) + + assert.Equal(t, 400, recorder.Code) + assert.Contains(t, recorder.Body.String(), "Bad Request") + }, + }, + { + description: "TOTP should fail on non-totp context", + middlewares: []gin.HandlerFunc{ + simpleCtx, + }, + run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { + totpReq := TotpRequest{ + Code: "123456", + } + + totpReqBody, err := json.Marshal(totpReq) + require.NoError(t, err) + + recorder = httptest.NewRecorder() + req := httptest.NewRequest("POST", "/api/user/totp", strings.NewReader(string(totpReqBody))) + req.Header.Set("Content-Type", "application/json") + router.ServeHTTP(recorder, req) + + assert.Equal(t, 401, recorder.Code) + assert.Contains(t, recorder.Body.String(), "Unauthorized") + }, + }, + { + description: "TOTP should fail when user in context doesn't exist", + middlewares: []gin.HandlerFunc{ + func(ctx *gin.Context) { + ctx.Set("context", &model.UserContext{ + Authenticated: false, + Provider: model.ProviderLocal, + Local: &model.LocalContext{ + BaseContext: model.BaseContext{ + Username: "idontexist", + Name: "Totpuser", + Email: "totpuser@example.com", + }, + TOTPPending: true, + }, + }) + ctx.Next() + }, + }, + run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { + totpReq := TotpRequest{ + Code: "123456", + } + + totpReqBody, err := json.Marshal(totpReq) + require.NoError(t, err) + + recorder = httptest.NewRecorder() + req := httptest.NewRequest("POST", "/api/user/totp", strings.NewReader(string(totpReqBody))) + req.Header.Set("Content-Type", "application/json") + router.ServeHTTP(recorder, req) + + assert.Equal(t, 401, recorder.Code) + assert.Contains(t, recorder.Body.String(), "Unauthorized") + }, + }, { description: "Should be able to login with totp", middlewares: []gin.HandlerFunc{ diff --git a/internal/controller/well_known_controller_test.go b/internal/controller/well_known_controller_test.go index 4098c152..8a969667 100644 --- a/internal/controller/well_known_controller_test.go +++ b/internal/controller/well_known_controller_test.go @@ -5,6 +5,7 @@ import ( "encoding/json" "fmt" "net/http/httptest" + "net/url" "testing" "github.com/gin-gonic/gin" @@ -25,12 +26,14 @@ func TestWellKnownController(t *testing.T) { type testCase struct { description string + oidcEnabled bool run func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) } tests := []testCase{ { description: "Ensure well-known endpoint returns correct OIDC configuration", + oidcEnabled: true, run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { req := httptest.NewRequest("GET", "/.well-known/openid-configuration", nil) router.ServeHTTP(recorder, req) @@ -39,7 +42,7 @@ func TestWellKnownController(t *testing.T) { res := OpenIDConnectConfiguration{} err := json.Unmarshal(recorder.Body.Bytes(), &res) - assert.NoError(t, err) + require.NoError(t, err) expected := OpenIDConnectConfiguration{ Issuer: runtime.AppURL, @@ -55,8 +58,8 @@ func TestWellKnownController(t *testing.T) { TokenEndpointAuthMethodsSupported: []string{"client_secret_basic", "client_secret_post"}, ClaimsSupported: []string{"sub", "updated_at", "name", "preferred_username", "email", "email_verified", "groups", "phone_number", "phone_number_verified", "address", "given_name", "family_name", "middle_name", "nickname", "profile", "picture", "website", "gender", "birthdate", "zoneinfo", "locale"}, ServiceDocumentation: "https://tinyauth.app/docs/guides/oidc", - RequestParameterSupported: true, RequestObjectSigningAlgValuesSupported: []string{"none"}, + RequestParameterSupported: true, } assert.Equal(t, expected, res) @@ -64,6 +67,7 @@ func TestWellKnownController(t *testing.T) { }, { description: "Ensure well-known endpoint returns correct JWKS", + oidcEnabled: true, run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { req := httptest.NewRequest("GET", "/.well-known/jwks.json", nil) router.ServeHTTP(recorder, req) @@ -72,19 +76,204 @@ func TestWellKnownController(t *testing.T) { decodedBody := make(map[string]any) err := json.Unmarshal(recorder.Body.Bytes(), &decodedBody) - assert.NoError(t, err) + require.NoError(t, err) keys, ok := decodedBody["keys"].([]any) - assert.True(t, ok) + require.True(t, ok) assert.Len(t, keys, 1) keyData, ok := keys[0].(map[string]any) - assert.True(t, ok) + require.True(t, ok) assert.Equal(t, "RSA", keyData["kty"]) assert.Equal(t, "sig", keyData["use"]) assert.Equal(t, "RS256", keyData["alg"]) }, }, + { + description: "Ensure openid configuration returns 500 on nil oidc service", + oidcEnabled: false, + run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { + req := httptest.NewRequest("GET", "/.well-known/openid-configuration", nil) + router.ServeHTTP(recorder, req) + + assert.Equal(t, 500, recorder.Code) + + decodedBody := make(map[string]any) + err := json.Unmarshal(recorder.Body.Bytes(), &decodedBody) + require.NoError(t, err) + + assert.Equal(t, "OIDC service not configured", decodedBody["message"]) + }, + }, + { + description: "Ensure jwks endpoint returns 500 on nil oidc service", + oidcEnabled: false, + run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { + req := httptest.NewRequest("GET", "/.well-known/jwks.json", nil) + router.ServeHTTP(recorder, req) + + assert.Equal(t, 500, recorder.Code) + + decodedBody := make(map[string]any) + err := json.Unmarshal(recorder.Body.Bytes(), &decodedBody) + require.NoError(t, err) + + assert.Equal(t, "OIDC service not configured", decodedBody["message"]) + }, + }, + { + description: "Ensure webfinger returns 400 on invalid resource", + oidcEnabled: true, + run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { + req := httptest.NewRequest("GET", "/.well-known/webfinger?resource=invalid-resource", nil) + router.ServeHTTP(recorder, req) + + assert.Equal(t, 400, recorder.Code) + assert.Equal(t, "application/jrd+json", recorder.Header().Get("content-type")) + assert.Equal(t, "*", recorder.Header().Get("access-control-allow-origin")) + + decodedBody := make(map[string]any) + err := json.Unmarshal(recorder.Body.Bytes(), &decodedBody) + require.NoError(t, err) + + assert.Equal(t, "invalid resource", decodedBody["message"]) + }, + }, + { + description: "Ensure webfinger resource validator allows acct", + oidcEnabled: true, + run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { + resource := "acct:testuser@example.com" + req := httptest.NewRequest("GET", fmt.Sprintf("/.well-known/webfinger?resource=%s", url.QueryEscape(resource)), nil) + router.ServeHTTP(recorder, req) + + assert.Equal(t, 200, recorder.Code) + assert.Equal(t, "application/jrd+json", recorder.Header().Get("content-type")) + assert.Equal(t, "*", recorder.Header().Get("access-control-allow-origin")) + }, + }, + { + description: "Ensure webfinger resource validator allows https", + oidcEnabled: true, + run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { + resource := "https://example.com/testuser" + req := httptest.NewRequest("GET", fmt.Sprintf("/.well-known/webfinger?resource=%s", url.QueryEscape(resource)), nil) + router.ServeHTTP(recorder, req) + + assert.Equal(t, 200, recorder.Code) + assert.Equal(t, "application/jrd+json", recorder.Header().Get("content-type")) + assert.Equal(t, "*", recorder.Header().Get("access-control-allow-origin")) + }, + }, + { + description: "Ensure webfinger resource validator allows http", + oidcEnabled: true, + run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { + resource := "http://example.com/testuser" + req := httptest.NewRequest("GET", fmt.Sprintf("/.well-known/webfinger?resource=%s", url.QueryEscape(resource)), nil) + router.ServeHTTP(recorder, req) + + assert.Equal(t, 200, recorder.Code) + assert.Equal(t, "application/jrd+json", recorder.Header().Get("content-type")) + assert.Equal(t, "*", recorder.Header().Get("access-control-allow-origin")) + }, + }, + { + description: "Webfinger should return no links when oidc is nil", + oidcEnabled: false, + run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { + resource := "acct:testuser@example.com" + req := httptest.NewRequest("GET", fmt.Sprintf("/.well-known/webfinger?resource=%s", url.QueryEscape(resource)), nil) + router.ServeHTTP(recorder, req) + + assert.Equal(t, 200, recorder.Code) + assert.Equal(t, "application/jrd+json", recorder.Header().Get("content-type")) + assert.Equal(t, "*", recorder.Header().Get("access-control-allow-origin")) + + decodedBody := make(map[string]any) + err := json.Unmarshal(recorder.Body.Bytes(), &decodedBody) + require.NoError(t, err) + + links, ok := decodedBody["links"].([]any) + require.True(t, ok) + assert.Len(t, links, 0) + }, + }, + { + description: "Webfinger should return links when oidc is configured and no rel is provided", + oidcEnabled: true, + run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { + resource := "acct:testuser@example.com" + req := httptest.NewRequest("GET", fmt.Sprintf("/.well-known/webfinger?resource=%s", url.QueryEscape(resource)), nil) + router.ServeHTTP(recorder, req) + + assert.Equal(t, 200, recorder.Code) + assert.Equal(t, "application/jrd+json", recorder.Header().Get("content-type")) + assert.Equal(t, "*", recorder.Header().Get("access-control-allow-origin")) + + decodedBody := make(map[string]any) + err := json.Unmarshal(recorder.Body.Bytes(), &decodedBody) + require.NoError(t, err) + + links, ok := decodedBody["links"].([]any) + require.True(t, ok) + assert.Len(t, links, 1) + + linkData, ok := links[0].(map[string]any) + require.True(t, ok) + assert.Equal(t, "http://openid.net/specs/connect/1.0/issuer", linkData["rel"]) + assert.Equal(t, runtime.AppURL, linkData["href"]) + }, + }, + { + description: "Webfinger should return links when oidc is configured and rel is provided", + oidcEnabled: true, + run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { + resource := fmt.Sprintf("acct:%s@%s", "testuser", runtime.AppURL) + rel := "http://openid.net/specs/connect/1.0/issuer" + req := httptest.NewRequest("GET", fmt.Sprintf("/.well-known/webfinger?resource=%s&rel=%s", url.QueryEscape(resource), url.QueryEscape(rel)), nil) + router.ServeHTTP(recorder, req) + + assert.Equal(t, 200, recorder.Code) + assert.Equal(t, "application/jrd+json", recorder.Header().Get("content-type")) + assert.Equal(t, "*", recorder.Header().Get("access-control-allow-origin")) + + decodedBody := make(map[string]any) + err := json.Unmarshal(recorder.Body.Bytes(), &decodedBody) + require.NoError(t, err) + + links, ok := decodedBody["links"].([]any) + require.True(t, ok) + assert.Len(t, links, 1) + + linkData, ok := links[0].(map[string]any) + require.True(t, ok) + assert.Equal(t, rel, linkData["rel"]) + assert.Equal(t, runtime.AppURL, linkData["href"]) + }, + }, + { + description: "Webfinger should return no links when oidc is configured and rel is provided but does not match", + oidcEnabled: true, + run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { + resource := "acct:testuser@example.com" + rel := "http://example.com/does-not-exist" + req := httptest.NewRequest("GET", fmt.Sprintf("/.well-known/webfinger?resource=%s&rel=%s", url.QueryEscape(resource), url.QueryEscape(rel)), nil) + router.ServeHTTP(recorder, req) + + assert.Equal(t, 200, recorder.Code) + assert.Equal(t, "application/jrd+json", recorder.Header().Get("content-type")) + assert.Equal(t, "*", recorder.Header().Get("access-control-allow-origin")) + + decodedBody := make(map[string]any) + err := json.Unmarshal(recorder.Body.Bytes(), &decodedBody) + require.NoError(t, err) + + links, ok := decodedBody["links"].([]any) + require.True(t, ok) + assert.Len(t, links, 0) + }, + }, } ctx := context.TODO() @@ -108,10 +297,15 @@ func TestWellKnownController(t *testing.T) { recorder := httptest.NewRecorder() - NewWellKnownController(WellKnownControllerInput{ - OIDCService: oidcService, + wellKnownControllerInput := WellKnownControllerInput{ RouterGroup: &router.RouterGroup, - }) + } + + if test.oidcEnabled { + wellKnownControllerInput.OIDCService = oidcService + } + + NewWellKnownController(wellKnownControllerInput) test.run(t, router, recorder) }) diff --git a/internal/test/test.go b/internal/test/test.go index df10f2b4..676501a4 100644 --- a/internal/test/test.go +++ b/internal/test/test.go @@ -76,6 +76,50 @@ func CreateTestConfigs(t *testing.T) (model.Config, model.RuntimeConfig) { Bypass: []string{"10.10.10.10"}, }, }, + "ip_block": { + Config: model.AppConfig{ + Domain: "ip-block.example.com", + }, + IP: model.AppIP{ + Block: []string{"10.10.10.10"}, + }, + }, + "oauth_group": { + Config: model.AppConfig{ + Domain: "oauth-group.example.com", + }, + OAuth: model.AppOAuth{ + Whitelist: "testuser@example.com", + Groups: "group1,group2", + }, + }, + "ldap_group": { + Config: model.AppConfig{ + Domain: "ldap-group.example.com", + }, + LDAP: model.AppLDAP{ + Groups: "group1,group2", + }, + }, + "basic_auth": { + Config: model.AppConfig{ + Domain: "basic-auth.example.com", + }, + Response: model.AppResponse{ + BasicAuth: model.AppBasicAuth{ + Username: "test", + Password: "password", + }, + }, + }, + "response_headers": { + Config: model.AppConfig{ + Domain: "response-headers.example.com", + }, + Response: model.AppResponse{ + Headers: []string{"x-foo=bar"}, + }, + }, }, }