1- import React from "react" ;
1+ import { Rank , tensor , Tensor4D } from "@tensorflow/tfjs" ;
2+ import React , { useMemo , useState } from "react" ;
23import { Col , Container , Row } from "react-grid-system" ;
4+ import tinycolor from "tinycolor2" ;
35import { AttentionPattern } from "./AttentionPattern" ;
6+ import { Tokens , TokensView } from "./components/AttentionTokens" ;
47import { useHoverLock , UseHoverLockState } from "./components/useHoverLock" ;
58
69/**
@@ -20,6 +23,52 @@ export function attentionHeadColor(
2023 return `hsla(${ hue } , 70%, 50%, ${ alpha } )` ;
2124}
2225
/**
 * Color the attention values by head.
 *
 * Each attention value is turned into an RGB triple, so the result has the
 * shape [heads x dest_tokens x src_tokens x rgb_color_channel]. Every head gets
 * its own hue (evenly spaced around the color wheel), which means an image of
 * a single head is colored rather than grayscale.
 *
 * Because the output is plain RGB, callers can also average over several heads
 * and still read off which head dominates a destination-source token pair: a
 * very red pair suggests the red head matters most for that combination.
 *
 * Lightness is computed as `1 - 0.75 * attention`, so 0 maps to white and 1 to
 * the darkest shade used. NOTE(review): this assumes attention values lie in
 * [0, 1]; values outside that range are handed to tinycolor unclamped.
 *
 * @param attentionInput Attention values as a
 * [heads x dest_tokens x src_tokens] nested array.
 *
 * @returns Tensor of the shape [heads x dest_tokens x src_tokens x
 * rgb_color_channel]. The caller owns this tensor and is responsible for
 * disposing of it.
 */
export function colorAttentionTensors(attentionInput: number[][][]): Tensor4D {
  // Create a TensorFlow tensor from the attention data and read it back as a
  // nested array, exactly as before.
  const attentionTensor = tensor<Rank.R3>(attentionInput); // [heads x dest_tokens x src_tokens]

  let attention: number[][][];
  try {
    attention = attentionTensor.arraySync() as number[][][];
  } finally {
    // TensorFlow.js tensors are not garbage collected. This intermediate is
    // only needed for the read-back above, so release it straight away; without
    // this every call (e.g. every `useMemo` recomputation) leaks one tensor.
    attentionTensor.dispose();
  }

  // Loop-invariant: number of heads, used to spread the hues over 360 degrees
  const headCount = attention.length;

  // Set the colors
  const colored = attention.map((head, headNumber) => {
    const hue = (headNumber / headCount) * 360; // Hue (degrees 0-360)

    return head.map((destination) =>
      destination.map((sourceAttention) => {
        const attentionColor = tinycolor({
          h: hue,
          s: 0.8, // Saturation (slightly off 100% to make less glaring)
          l: 1 - 0.75 * sourceAttention // Luminance (shows amount of attention)
        });

        // Return as a nested list in the format [red, green, blue]
        const { r, g, b } = attentionColor.toRgb();
        return [r, g, b];
      })
    );
  });

  return tensor<Rank.R4>(colored);
}
71+
2372/**
2473 * Attention Heads Selector
2574 */
@@ -115,14 +164,34 @@ export function AttentionHeads({
115164 negativeColor,
116165 positiveColor,
117166 maskUpperTri = true ,
118- tokens
167+ tokens,
168+ showTokens = true
119169} : AttentionHeadsProps ) {
120170 // Attention head focussed state
121171 const { focused, onClick, onMouseEnter, onMouseLeave } = useHoverLock ( 0 ) ;
122172
173+ // State for the token view type
174+ const [ tokensView , setTokensView ] = useState < TokensView > (
175+ TokensView . DESTINATION_TO_SOURCE
176+ ) ;
177+
178+ // State for which token is focussed
179+ const {
180+ focused : focusedToken ,
181+ onClick : onClickToken ,
182+ onMouseEnter : onMouseEnterToken ,
183+ onMouseLeave : onMouseLeaveToken
184+ } = useHoverLock ( ) ;
185+
123186 const headNames =
124187 attentionHeadNames || attention . map ( ( _ , idx ) => `Head ${ idx } ` ) ;
125188
189+ // Color the attention values (by head) for token interaction
190+ const coloredAttention = useMemo (
191+ ( ) => colorAttentionTensors ( attention ) ,
192+ [ attention ]
193+ ) ;
194+
126195 return (
127196 < Container >
128197 < h3 style = { { marginBottom : 15 } } >
@@ -176,6 +245,51 @@ export function AttentionHeads({
176245 </ Col >
177246 </ Row >
178247
248+ { showTokens && (
249+ < Row >
250+ < Col xs = { 12 } >
251+ < div className = "tokens" style = { { marginTop : 20 } } >
252+ < h4 style = { { display : "inline-block" , marginRight : 15 } } >
253+ Interactive Tokens
254+ < span style = { { fontWeight : "normal" } } >
255+ { " " }
256+ (hover/click to explore attention)
257+ </ span >
258+ </ h4 >
259+ < select
260+ value = { tokensView }
261+ onChange = { ( e ) => setTokensView ( e . target . value as TokensView ) }
262+ style = { {
263+ marginLeft : 10 ,
264+ padding : "5px 10px" ,
265+ borderRadius : 4 ,
266+ border : "1px solid #ccc"
267+ } }
268+ >
269+ < option value = { TokensView . DESTINATION_TO_SOURCE } >
270+ Source ← Destination
271+ </ option >
272+ < option value = { TokensView . SOURCE_TO_DESTINATION } >
273+ Destination ← Source
274+ </ option >
275+ </ select >
276+ < div style = { { marginTop : 10 } } >
277+ < Tokens
278+ coloredAttention = { coloredAttention }
279+ focusedHead = { focused }
280+ focusedToken = { focusedToken }
281+ onClickToken = { onClickToken }
282+ onMouseEnterToken = { onMouseEnterToken }
283+ onMouseLeaveToken = { onMouseLeaveToken }
284+ tokens = { tokens }
285+ tokensView = { tokensView }
286+ />
287+ </ div >
288+ </ div >
289+ </ Col >
290+ </ Row >
291+ ) }
292+
179293 < Row > </ Row >
180294 </ Container >
181295 ) ;
@@ -268,4 +382,13 @@ export interface AttentionHeadsProps {
268382 * Must be the same length as the list of values.
269383 */
270384 tokens : string [ ] ;
385+
386+ /**
387+ * Show interactive tokens
388+ *
389+ * Whether to display the interactive token layer with hover/click functionality.
390+ *
391+ * @default true
392+ */
393+ showTokens ?: boolean ;
271394}
0 commit comments