
Commit 2235f29

Julian Aßmann (JulianAssmann-SAP) authored and committed
Add interactive token attention visualization to attention_heads
1 parent f8c1c81 commit 2235f29

5 files changed

Lines changed: 1383 additions & 15859 deletions


python/Demonstration.ipynb

Lines changed: 1232 additions & 15853 deletions
Large diffs are not rendered by default.

python/circuitsvis/attention.py

Lines changed: 6 additions & 0 deletions
@@ -1,8 +1,10 @@
 """Attention visualisations"""
+
 from typing import List, Optional, Union
 
 import numpy as np
 import torch
+
 from circuitsvis.utils.render import RenderedHTML, render
 
 
@@ -15,6 +17,7 @@ def attention_heads(
     negative_color: Optional[str] = None,
     positive_color: Optional[str] = None,
     mask_upper_tri: Optional[bool] = None,
+    show_tokens: Optional[bool] = None,
 ) -> RenderedHTML:
     """Attention Heads
 
@@ -41,6 +44,8 @@ def attention_heads(
         mask_upper_tri: Whether or not to mask the upper triangular portion of
             the attention patterns. Should be true for causal attention, false for
             bidirectional attention.
+        show_tokens: Whether to display an interactive token view showing the
+            attention from each token to all other tokens.
 
     Returns:
         Html: Attention pattern visualization
@@ -54,6 +59,7 @@ def attention_heads(
         "positiveColor": positive_color,
         "tokens": tokens,
         "maskUpperTri": mask_upper_tri,
+        "showTokens": show_tokens,
     }
 
     return render(
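
As a quick illustration of the new flag from the Python side (a minimal sketch: the token strings and attention values are invented, and the call assumes the function's existing attention and tokens parameters):

import numpy as np

from circuitsvis.attention import attention_heads

# Hypothetical input: 4 heads over 3 tokens, shaped [heads x dest_tokens x src_tokens]
tokens = ["The", "cat", "sat"]
attention = np.random.rand(4, 3, 3)

# show_tokens=True enables the interactive token view added in this commit
html = attention_heads(attention=attention, tokens=tokens, show_tokens=True)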

react/package.json

Lines changed: 4 additions & 4 deletions
@@ -47,11 +47,11 @@
     "@storybook/addon-actions": "^6.5.14",
     "@storybook/addon-essentials": "^6.5.14",
     "@storybook/addon-interactions": "^6.5.14",
-    "@storybook/addon-links": "^6.5.14",
-    "@storybook/builder-webpack5": "^6.5.14",
+    "@storybook/addon-links": "^9.0.16",
+    "@storybook/builder-webpack5": "^9.0.16",
     "@storybook/manager-webpack5": "^6.5.14",
     "@storybook/preset-typescript": "^3.0.0",
-    "@storybook/react": "^6.5.14",
+    "@storybook/react": "^9.0.16",
     "@storybook/testing-library": "^0.0.13",
     "@tensorflow/tfjs-node": "^4.1.0",
     "@testing-library/jest-dom": "5.16.5",
@@ -79,7 +79,7 @@
     "eslint-plugin-prettier": "4.2.1",
     "eslint-plugin-react": "7.31.11",
     "eslint-plugin-react-hooks": "4.6.0",
-    "eslint-plugin-storybook": "^0.6.8",
+    "eslint-plugin-storybook": "^9.0.16",
     "eslint-plugin-testing-library": "5.9.1",
     "jest": "^29.3.1",
     "jest-canvas-mock": "^2.4.0",

react/src/attention/AttentionHeads.stories.tsx

Lines changed: 16 additions & 0 deletions
@@ -22,3 +22,19 @@ InductionHeadsLayer.args = {
   tokens: mockTokens,
   attention: mockAttention
 };
+
+export const InteractiveTokensDemo: ComponentStory<typeof AttentionHeads> =
+  Template.bind({});
+InteractiveTokensDemo.args = {
+  tokens: mockTokens,
+  attention: mockAttention,
+  showTokens: true
+};
+
+export const WithoutTokens: ComponentStory<typeof AttentionHeads> =
+  Template.bind({});
+WithoutTokens.args = {
+  tokens: mockTokens,
+  attention: mockAttention,
+  showTokens: false
+};

react/src/attention/AttentionHeads.tsx

Lines changed: 125 additions & 2 deletions
@@ -1,6 +1,9 @@
-import React from "react";
+import { Rank, tensor, Tensor4D } from "@tensorflow/tfjs";
+import React, { useMemo, useState } from "react";
 import { Col, Container, Row } from "react-grid-system";
+import tinycolor from "tinycolor2";
 import { AttentionPattern } from "./AttentionPattern";
+import { Tokens, TokensView } from "./components/AttentionTokens";
 import { useHoverLock, UseHoverLockState } from "./components/useHoverLock";
 
 /**
@@ -20,6 +23,52 @@ export function attentionHeadColor(
   return `hsla(${hue}, 70%, 50%, ${alpha})`;
 }
 
+/**
+ * Color the attention values by heads
+ *
+ * We want attention values to be colored by each head (i.e. becoming [heads x
+ * dest_tokens x src_tokens x rgb_color_channel]). This way, when outputting an
+ * image of just one attention head it will be colored (by the specific hue
+ * assigned to that attention head) rather than grayscale.
+ *
+ * Importantly, when outputting an image that averages
+ * several attention heads we can then also average over the colors (so that we
+ * can see for each destination-source token pair which head is most important).
+ * For example, if the specific pair is very red, it suggests that the red
+ * attention head is most important for this destination-source token combination.
+ *
+ * @param attentionInput Attention input as [heads x dest_tokens x source_tokens] array
+ *
+ * @returns Tensor of the shape [heads x dest_tokens x src_tokens x
+ * rgb_color_channel]
+ */
+export function colorAttentionTensors(attentionInput: number[][][]): Tensor4D {
+  // Create a TensorFlow tensor from the attention data
+  const attentionTensor = tensor<Rank.R3>(attentionInput); // [heads x dest_tokens x source_tokens]
+
+  const attention = attentionTensor.arraySync() as number[][][];
+
+  // Set the colors
+  const colored = attention.map((head, headNumber) =>
+    head.map((destination) =>
+      destination.map((sourceAttention) => {
+        // Color
+        const attentionColor = tinycolor({
+          h: (headNumber / attention.length) * 360, // Hue (degrees 0-360)
+          s: 0.8, // Saturation (slightly off 100% to make less glaring)
+          l: 1 - 0.75 * sourceAttention // Luminance (shows amount of attention)
+        });
+
+        // Return as a nested list in the format [red, green, blue]
+        const { r, g, b } = attentionColor.toRgb();
+        return [r, g, b];
+      })
+    )
+  );
+
+  return tensor(colored);
+}
+
 /**
  * Attention Heads Selector
  */
@@ -115,14 +164,34 @@ export function AttentionHeads({
   negativeColor,
   positiveColor,
   maskUpperTri = true,
-  tokens
+  tokens,
+  showTokens = true
 }: AttentionHeadsProps) {
   // Attention head focussed state
   const { focused, onClick, onMouseEnter, onMouseLeave } = useHoverLock(0);
 
+  // State for the token view type
+  const [tokensView, setTokensView] = useState<TokensView>(
+    TokensView.DESTINATION_TO_SOURCE
+  );
+
+  // State for which token is focussed
+  const {
+    focused: focusedToken,
+    onClick: onClickToken,
+    onMouseEnter: onMouseEnterToken,
+    onMouseLeave: onMouseLeaveToken
+  } = useHoverLock();
+
   const headNames =
     attentionHeadNames || attention.map((_, idx) => `Head ${idx}`);
 
+  // Color the attention values (by head) for token interaction
+  const coloredAttention = useMemo(
+    () => colorAttentionTensors(attention),
+    [attention]
+  );
+
   return (
     <Container>
       <h3 style={{ marginBottom: 15 }}>
@@ -176,6 +245,51 @@ export function AttentionHeads({
         </Col>
       </Row>
 
+      {showTokens && (
+        <Row>
+          <Col xs={12}>
+            <div className="tokens" style={{ marginTop: 20 }}>
+              <h4 style={{ display: "inline-block", marginRight: 15 }}>
+                Interactive Tokens
+                <span style={{ fontWeight: "normal" }}>
+                  {" "}
+                  (hover/click to explore attention)
+                </span>
+              </h4>
+              <select
+                value={tokensView}
+                onChange={(e) => setTokensView(e.target.value as TokensView)}
+                style={{
+                  marginLeft: 10,
+                  padding: "5px 10px",
+                  borderRadius: 4,
+                  border: "1px solid #ccc"
+                }}
+              >
+                <option value={TokensView.DESTINATION_TO_SOURCE}>
+                  Source ← Destination
+                </option>
+                <option value={TokensView.SOURCE_TO_DESTINATION}>
+                  Destination ← Source
+                </option>
+              </select>
+              <div style={{ marginTop: 10 }}>
+                <Tokens
+                  coloredAttention={coloredAttention}
+                  focusedHead={focused}
+                  focusedToken={focusedToken}
+                  onClickToken={onClickToken}
+                  onMouseEnterToken={onMouseEnterToken}
+                  onMouseLeaveToken={onMouseLeaveToken}
+                  tokens={tokens}
+                  tokensView={tokensView}
+                />
+              </div>
+            </div>
+          </Col>
+        </Row>
+      )}
+
       <Row></Row>
     </Container>
   );
@@ -268,4 +382,13 @@ export interface AttentionHeadsProps {
    * Must be the same length as the list of values.
    */
   tokens: string[];
+
+  /**
+   * Show interactive tokens
+   *
+   * Whether to display the interactive token layer with hover/click functionality.
+   *
+   * @default true
+   */
+  showTokens?: boolean;
 }
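
As a cross-check on the color mapping that colorAttentionTensors implements, here is a rough Python equivalent of the same per-head HSL scheme (a sketch using the standard-library colorsys in place of tinycolor2; colorsys takes hue and lightness in [0, 1], so the hue is left unscaled instead of converted to degrees):

import colorsys

def head_token_color(head: int, num_heads: int, attention: float) -> tuple:
    """One hue per head; lightness encodes attention (1.0 = white = none)."""
    hue = head / num_heads            # tinycolor uses degrees, colorsys uses [0, 1]
    lightness = 1 - 0.75 * attention  # stronger attention = darker
    saturation = 0.8                  # slightly under 100% to be less glaring
    r, g, b = colorsys.hls_to_rgb(hue, lightness, saturation)
    return round(r * 255), round(g * 255), round(b * 255)

# Head 0 of 4 with full attention gives a dark saturated red:
print(head_token_color(0, 4, 1.0))  # -> (115, 13, 13)

Averaging these RGB values across heads is what lets the combined view hint at which head dominates a given destination-source pair, and the Source ← Destination / Destination ← Source selector in effect switches between reading each head's colored pattern along rows or along columns for the focused token.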
