-
Notifications
You must be signed in to change notification settings - Fork 44
Expand file tree
/
Copy pathactivations.py
More file actions
63 lines (56 loc) · 2.42 KB
/
activations.py
File metadata and controls
63 lines (56 loc) · 2.42 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
"""Activations visualizations"""
from typing import List, Union, Optional
import numpy as np
import torch
from circuitsvis.utils.render import RenderedHTML, render
def text_neuron_activations(
    tokens: Union[List[str], List[List[str]]],
    activations: Union[np.ndarray, torch.Tensor, List[np.ndarray], List[torch.Tensor]],
    first_dimension_name: Optional[str] = "Layer",
    second_dimension_name: Optional[str] = "Neuron",
    first_dimension_labels: Optional[List[str]] = None,
    second_dimension_labels: Optional[List[str]] = None,
    first_dimension_default: Optional[int] = 0,
    second_dimension_default: Optional[int] = 0,
    show_selectors: Optional[bool] = True,
) -> RenderedHTML:
    """Show activations (colored by intensity) for each token in a text or set
    of texts.

    Includes drop-downs for layer and neuron numbers.

    Args:
        tokens: List of tokens if single sample (e.g. ``["A", "person"]``) or
            list of lists of tokens if multiple samples (e.g.
            ``[["A", "person"], ["is", "walking"]]``). Passed through to the
            frontend unchanged.
        activations: Activations of shape [tokens x layers x neurons] if
            single sample, or a list (or tuple) of such arrays/tensors, one
            per sample, if multiple samples. Each array/tensor must be 3-D.
        first_dimension_name: Name shown for the first selectable dimension
            (the layers axis of ``activations`` by default). Defaults to
            "Layer".
        second_dimension_name: Name shown for the second selectable dimension
            (the neurons axis of ``activations`` by default). Defaults to
            "Neuron".
        first_dimension_labels: Optional labels for the first dimension,
            forwarded to the frontend as ``firstDimensionLabels`` (presumably
            one label per index; confirm against the frontend component).
        second_dimension_labels: Optional labels for the second dimension,
            forwarded to the frontend as ``secondDimensionLabels``.
        first_dimension_default: Forwarded as ``firstDimensionDefault``;
            presumably the index initially selected for the first dimension.
        second_dimension_default: Forwarded as ``secondDimensionDefault``;
            presumably the index initially selected for the second dimension.
        show_selectors: Forwarded as ``showSelectors``; presumably toggles the
            drop-down selectors.

    Returns:
        Html: Text neuron activations visualization

    Raises:
        ValueError: If any activations array/tensor is not 3-dimensional.
        TypeError: If ``activations`` (or an element of it, when given as a
            list/tuple) is not an ``np.ndarray`` or ``torch.Tensor``.
    """
    # Validate the rank of the activations and convert them to nested Python
    # lists so they can be serialized for the frontend. Tokens are passed
    # through as-is; their lengths are not checked against the activations.
    # Explicit raises are used instead of `assert` so the checks still run
    # under `python -O`.
    shape_message = "activations must be of shape [tokens x layers x neurons]"

    if isinstance(activations, (np.ndarray, torch.Tensor)):
        if activations.ndim != 3:
            raise ValueError(
                f"{shape_message}, got {activations.ndim} dimension(s)"
            )
        activations_list = activations.tolist()
    elif isinstance(activations, (list, tuple)):
        activations_list = []
        for index, act in enumerate(activations):
            # Without this check a non-array element fails with an opaque
            # AttributeError on `.ndim`.
            if not isinstance(act, (np.ndarray, torch.Tensor)):
                raise TypeError(
                    "each element of activations must be an np.ndarray or "
                    f"torch.Tensor, but element {index} is {type(act)}"
                )
            if act.ndim != 3:
                raise ValueError(
                    f"{shape_message}, but element {index} has "
                    f"{act.ndim} dimension(s)"
                )
            activations_list.append(act.tolist())
    else:
        raise TypeError(
            "activations must be of type np.ndarray, torch.Tensor, or a list "
            f"of these, not {type(activations)}"
        )

    return render(
        "TextNeuronActivations",
        tokens=tokens,
        activations=activations_list,
        firstDimensionName=first_dimension_name,
        secondDimensionName=second_dimension_name,
        firstDimensionLabels=first_dimension_labels,
        secondDimensionLabels=second_dimension_labels,
        firstDimensionDefault=first_dimension_default,
        secondDimensionDefault=second_dimension_default,
        showSelectors=show_selectors,
    )