Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 9 additions & 4 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,13 @@ help:
@echo "mypy-all Run mypy ignoring the config in pyproject.tom but still ignoring missing imports"
@echo "test Run the unit tests."
@echo "update-config-docstring Update the app's config docstring so mkdocs can autogenerate it correctly."
@echo "frontend-install Install the pnpm modules needed for the front end"
@echo "frontend-build Build the frontend in order to run on localhost:9090"
@echo "frontend-install Install the pnpm modules needed for the frontend"
@echo "frontend-build Build the frontend for localhost:9090"
@echo "frontend-test Run the frontend test suite once"
@echo "frontend-dev Run the frontend in developer mode on localhost:5173"
@echo "frontend-typegen Generate types for the frontend from the OpenAPI schema"
@echo "frontend-prettier Format the frontend using lint:prettier"
@echo "wheel Build the wheel for the current version"
@echo "frontend-lint Run frontend checks and fixable lint/format steps"
@echo "wheel Build the wheel for the current version"
@echo "tag-release Tag the GitHub repository with the current version (use at release time only!)"
@echo "openapi Generate the OpenAPI schema for the app, outputting to stdout"
@echo "docs Serve the mkdocs site with live reload"
Expand Down Expand Up @@ -57,6 +58,10 @@ frontend-install:
frontend-build:
cd invokeai/frontend/web && pnpm build

# Run the frontend test suite once
frontend-test:
cd invokeai/frontend/web && pnpm run test:run

# Run the frontend in dev mode
frontend-dev:
cd invokeai/frontend/web && pnpm dev
Expand Down
1 change: 1 addition & 0 deletions invokeai/app/api/routers/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@ async def login(
user_id=user.user_id,
email=user.email,
is_admin=user.is_admin,
remember_me=request.remember_me,
)
token = create_access_token(token_data, expires_delta)

Expand Down
46 changes: 46 additions & 0 deletions invokeai/app/api_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,50 @@ async def lifespan(app: FastAPI):
)


class SlidingWindowTokenMiddleware(BaseHTTPMiddleware):
"""Refresh the JWT token on each authenticated response.

When a request includes a valid Bearer token, the response includes a
X-Refreshed-Token header with a new token that has a fresh expiry.
This implements sliding-window session expiry: the session only expires
after a period of *inactivity*, not a fixed time after login.
"""

async def dispatch(self, request: Request, call_next: RequestResponseEndpoint):
response = await call_next(request)

# Only refresh on mutating requests (POST/PUT/PATCH/DELETE) — these indicate
# genuine user activity. GET requests are often background fetches (RTK Query
# cache revalidation, refetch-on-focus, etc.) and should not reset the
# inactivity timer.
if response.status_code < 400 and request.method in ("POST", "PUT", "PATCH", "DELETE"):
auth_header = request.headers.get("authorization", "")
if auth_header.startswith("Bearer "):
token = auth_header[7:]
try:
from datetime import timedelta

from invokeai.app.api.routers.auth import TOKEN_EXPIRATION_NORMAL, TOKEN_EXPIRATION_REMEMBER_ME
from invokeai.app.services.auth.token_service import create_access_token, verify_token

token_data = verify_token(token)
if token_data is not None:
# Use the remember_me claim from the token to determine the
# correct refresh duration. This avoids the bug where a 7-day
# token with <24h remaining would be silently downgraded to 1 day.
if token_data.remember_me:
expires_delta = timedelta(days=TOKEN_EXPIRATION_REMEMBER_ME)
else:
expires_delta = timedelta(days=TOKEN_EXPIRATION_NORMAL)

new_token = create_access_token(token_data, expires_delta)
response.headers["X-Refreshed-Token"] = new_token
except Exception:
pass # Don't fail the request if token refresh fails

return response


class RedirectRootWithQueryStringMiddleware(BaseHTTPMiddleware):
"""When a request is made to the root path with a query string, redirect to the root path without the query string.

Expand All @@ -99,6 +143,7 @@ async def dispatch(self, request: Request, call_next: RequestResponseEndpoint):

# Add the middleware
app.add_middleware(RedirectRootWithQueryStringMiddleware)
app.add_middleware(SlidingWindowTokenMiddleware)


# Add event handler
Expand All @@ -117,6 +162,7 @@ async def dispatch(self, request: Request, call_next: RequestResponseEndpoint):
allow_credentials=app_config.allow_credentials,
allow_methods=app_config.allow_methods,
allow_headers=app_config.allow_headers,
expose_headers=["X-Refreshed-Token"],
)

app.add_middleware(GZipMiddleware, minimum_size=1000)
Expand Down
1 change: 1 addition & 0 deletions invokeai/app/services/auth/token_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ class TokenData(BaseModel):
user_id: str
email: str
is_admin: bool
remember_me: bool = False


def set_jwt_secret(secret: str) -> None:
Expand Down
3 changes: 2 additions & 1 deletion invokeai/frontend/web/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
"scripts": {
"dev": "vite dev",
"dev:host": "vite dev --host",
"build": "pnpm run lint && vite build",
"build": "pnpm run lint && vitest run && vite build",
"typegen": "node scripts/typegen.js",
"preview": "vite preview",
"lint:knip": "knip --tags=-knipignore",
Expand All @@ -35,6 +35,7 @@
"storybook": "storybook dev -p 6006",
"build-storybook": "storybook build",
"test": "vitest",
"test:run": "vitest run",
"test:ui": "vitest --coverage --ui",
"test:no-watch": "vitest --no-watch"
},
Expand Down
3 changes: 2 additions & 1 deletion invokeai/frontend/web/public/locales/en.json
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@
"rememberMe": "Remember me for 7 days",
"signIn": "Sign In",
"signingIn": "Signing in...",
"loginFailed": "Login failed. Please check your credentials."
"loginFailed": "Login failed. Please check your credentials.",
"sessionExpired": "Your credentials have expired. Please log in again to resume."
},
"setup": {
"title": "Welcome to InvokeAI",
Expand Down
13 changes: 13 additions & 0 deletions invokeai/frontend/web/src/common/hooks/focus.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
import { describe, expect, it } from 'vitest';

import { getFocusedRegion, setFocusedRegion } from './focus';

describe('focus regions', () => {
it('supports the workflows region', () => {
setFocusedRegion('workflows');
expect(getFocusedRegion()).toBe('workflows');

setFocusedRegion(null);
expect(getFocusedRegion()).toBe(null);
});
});
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@ import {
Text,
VStack,
} from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { setCredentials } from 'features/auth/store/authSlice';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { selectSessionExpired, setCredentials } from 'features/auth/store/authSlice';
import type { ChangeEvent, FormEvent } from 'react';
import { memo, useCallback, useEffect, useState } from 'react';
import { useTranslation } from 'react-i18next';
Expand All @@ -29,6 +29,7 @@ export const LoginPage = memo(() => {
const [rememberMe, setRememberMe] = useState(true);
const [login, { isLoading, error }] = useLoginMutation();
const dispatch = useAppDispatch();
const sessionExpired = useAppSelector(selectSessionExpired);
const { data: setupStatus, isLoading: isLoadingSetup } = useGetSetupStatusQuery();

// Redirect to app if multiuser mode is disabled
Expand Down Expand Up @@ -114,6 +115,12 @@ export const LoginPage = memo(() => {
{t('auth.login.title')}
</Heading>

{sessionExpired && (
<Flex p={3} borderRadius="md" bg="warning.600" color="white" fontSize="sm" justifyContent="center">
<Text fontWeight="semibold">{t('auth.login.sessionExpired')}</Text>
</Flex>
)}

<FormControl isRequired isInvalid={!!errorMessage}>
<FormLabel>{t('auth.login.email')}</FormLabel>
<Input
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import { Center, Spinner } from '@invoke-ai/ui-library';
import type { RootState } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { logout, setCredentials } from 'features/auth/store/authSlice';
import { logout, sessionExpiredLogout, setCredentials } from 'features/auth/store/authSlice';
import type { PropsWithChildren } from 'react';
import { memo, useEffect } from 'react';
import { useNavigate } from 'react-router-dom';
Expand Down Expand Up @@ -33,13 +33,42 @@ export const ProtectedRoute = memo(({ children, requireAdmin = false }: PropsWit
});

useEffect(() => {
// If we have a token but fetching user failed, token is invalid - logout
if (userError && isAuthenticated) {
dispatch(logout());
// Only treat 401 as session expiry. Other errors (500, network, etc.) are
// transient and should not force logout — the 401 handler in dynamicBaseQuery
// already covers the actual expiry case.
if (userError && isAuthenticated && 'status' in userError && userError.status === 401) {
dispatch(sessionExpiredLogout());
navigate('/login', { replace: true });
}
}, [userError, isAuthenticated, dispatch, navigate]);

// Detect when auth_token is removed from localStorage (e.g. by another tab,
// browser devtools, or token expiry cleanup). The 'storage' event fires when
// localStorage is modified by another context; we also poll periodically to
// catch same-tab deletions (which don't trigger the storage event).
useEffect(() => {
if (!multiuserEnabled || !isAuthenticated) {
return;
}

const checkToken = () => {
if (!localStorage.getItem('auth_token') && isAuthenticated) {
dispatch(sessionExpiredLogout());
navigate('/login', { replace: true });
}
};

// Listen for cross-tab localStorage changes
window.addEventListener('storage', checkToken);
// Poll for same-tab deletions (e.g. browser console)
const interval = setInterval(checkToken, 5000);

return () => {
window.removeEventListener('storage', checkToken);
clearInterval(interval);
};
}, [multiuserEnabled, isAuthenticated, dispatch, navigate]);

useEffect(() => {
// If we successfully fetched user data, update auth state
if (currentUser && token && !user) {
Expand Down
18 changes: 16 additions & 2 deletions invokeai/frontend/web/src/features/auth/store/authSlice.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ const zAuthState = z.object({
token: z.string().nullable(),
user: zUser.nullable(),
isLoading: z.boolean(),
sessionExpired: z.boolean(),
});

type User = z.infer<typeof zUser>;
Expand All @@ -34,6 +35,7 @@ const initialState: AuthState = {
token: getStoredAuthToken(),
user: null,
isLoading: false,
sessionExpired: false,
};

const getInitialAuthState = (): AuthState => initialState;
Expand All @@ -46,6 +48,7 @@ const authSlice = createSlice({
state.token = action.payload.token;
state.user = action.payload.user;
state.isAuthenticated = true;
state.sessionExpired = false;
if (typeof window !== 'undefined' && window.localStorage) {
localStorage.setItem('auth_token', action.payload.token);
}
Expand All @@ -54,6 +57,16 @@ const authSlice = createSlice({
state.token = null;
state.user = null;
state.isAuthenticated = false;
state.sessionExpired = false;
if (typeof window !== 'undefined' && window.localStorage) {
localStorage.removeItem('auth_token');
}
},
sessionExpiredLogout: (state) => {
state.token = null;
state.user = null;
state.isAuthenticated = false;
state.sessionExpired = true;
if (typeof window !== 'undefined' && window.localStorage) {
localStorage.removeItem('auth_token');
}
Expand All @@ -64,7 +77,7 @@ const authSlice = createSlice({
},
});

export const { setCredentials, logout, setLoading } = authSlice.actions;
export const { setCredentials, logout, sessionExpiredLogout, setLoading } = authSlice.actions;

export const authSliceConfig: SliceConfig<typeof authSlice> = {
slice: authSlice,
Expand All @@ -73,11 +86,12 @@ export const authSliceConfig: SliceConfig<typeof authSlice> = {
persistConfig: {
migrate: () => getInitialAuthState(),
// Don't persist auth state - token is stored in localStorage
persistDenylist: ['isAuthenticated', 'token', 'user', 'isLoading'],
persistDenylist: ['isAuthenticated', 'token', 'user', 'isLoading', 'sessionExpired'],
},
};

export const selectIsAuthenticated = (state: { auth: AuthState }) => state.auth.isAuthenticated;
export const selectCurrentUser = (state: { auth: AuthState }) => state.auth.user;
export const selectAuthToken = (state: { auth: AuthState }) => state.auth.token;
export const selectIsAuthLoading = (state: { auth: AuthState }) => state.auth.isLoading;
export const selectSessionExpired = (state: { auth: AuthState }) => state.auth.sessionExpired;
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ const addControlLayerFromImageDndTargetData = newCanvasEntityFromImageDndTarget.
const addRegionalGuidanceReferenceImageFromImageDndTargetData = newCanvasEntityFromImageDndTarget.getData({
type: 'regional_guidance_with_reference_image',
});
const addInpaintMaskFromImageDndTargetData = newCanvasEntityFromImageDndTarget.getData({ type: 'inpaint_mask' });
const addResizedControlLayerFromImageDndTargetData = newCanvasEntityFromImageDndTarget.getData({
type: 'control_layer',
withResize: true,
Expand All @@ -25,39 +26,47 @@ export const CanvasDropArea = memo(() => {
<>
<Grid
gridTemplateRows="1fr 1fr"
gridTemplateColumns="1fr 1fr"
gridTemplateColumns="repeat(6, 1fr)"
position="absolute"
top={0}
right={0}
bottom={0}
left={0}
pointerEvents="none"
>
<GridItem position="relative">
<GridItem position="relative" colSpan={3}>
<DndDropTarget
dndTarget={newCanvasEntityFromImageDndTarget}
dndTargetData={addRasterLayerFromImageDndTargetData}
label={t('controlLayers.canvasContextMenu.newRasterLayer')}
isDisabled={isBusy}
/>
</GridItem>
<GridItem position="relative">
<GridItem position="relative" colSpan={3}>
<DndDropTarget
dndTarget={newCanvasEntityFromImageDndTarget}
dndTargetData={addControlLayerFromImageDndTargetData}
label={t('controlLayers.canvasContextMenu.newControlLayer')}
isDisabled={isBusy}
/>
</GridItem>
<GridItem position="relative">
<GridItem position="relative" colSpan={2}>
<DndDropTarget
dndTarget={newCanvasEntityFromImageDndTarget}
dndTargetData={addRegionalGuidanceReferenceImageFromImageDndTargetData}
label={t('controlLayers.canvasContextMenu.newRegionalReferenceImage')}
isDisabled={isBusy}
/>
</GridItem>
<GridItem position="relative">
<GridItem position="relative" colSpan={2}>
<DndDropTarget
dndTarget={newCanvasEntityFromImageDndTarget}
dndTargetData={addInpaintMaskFromImageDndTargetData}
label={t('controlLayers.canvasContextMenu.newInpaintMask')}
isDisabled={isBusy}
/>
</GridItem>
<GridItem position="relative" colSpan={2}>
<DndDropTarget
dndTarget={newCanvasEntityFromImageDndTarget}
dndTargetData={addResizedControlLayerFromImageDndTargetData}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -542,6 +542,7 @@ export abstract class CanvasEntityAdapterBase<T extends CanvasEntityState, U ext
this.renderer.updateCompositingRectSize();
this.renderer.updateCompositingRectPosition();
this.renderer.updateCompositingRectFill();
this.renderer.updateOpacity();
}
this.renderer.syncKonvaCache();
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -320,10 +320,6 @@ export class CanvasEntityObjectRenderer extends CanvasModuleBase {
};

updateOpacity = throttle(() => {
if (!this.parent.konva.layer.visible()) {
return;
}

this.log.trace('Updating opacity');

const opacity = this.parent.state.opacity;
Expand Down
Loading
Loading