Commit 5847dc2

CiMLoop ISAAC + Wang + Basic Analog
1 parent 731666c commit 5847dc2

22 files changed: 11,106 additions & 993 deletions

examples/arches/compute_in_memory/_include.yaml

Lines changed: 4 additions & 1 deletion
@@ -53,6 +53,9 @@ variables_global: &variables_global
 
   average_input_bits_per_slice: encoded_input_bits / n_input_slices
   average_weight_bits_per_slice: encoded_weight_bits / n_weight_slices
+  average_input_bits_per_sliced_psum: encoded_input_bits / n_sliced_psums
+  average_weight_bits_per_sliced_psum: encoded_weight_bits / n_sliced_psums
+  average_output_bits_per_sliced_psum: encoded_output_bits / n_sliced_psums
 
   # This is for the bitwise-multiplication of the input and weight slices
   n_virtual_macs: max_input_bits_per_slice * max_weight_bits_per_slice * encoded_output_bits
@@ -71,4 +74,4 @@ variables_global: &variables_global
 
   n_input_slices: max(ceil(in_b / max_input_bits_per_slice), min_input_slices)
   n_weight_slices: max(ceil(w_b / max_weight_bits_per_slice), min_weight_slices)
-  n_sliced_psums: n_input_slices * n_weight_slices
+  n_sliced_psums: n_input_slices * n_weight_slices
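The three new per-sliced-psum variables divide each encoded bit width by n_sliced_psums, the product of the input and weight slice counts. A minimal sketch of the arithmetic in Python, using illustrative operand widths and slice limits that are not taken from this commit:

from math import ceil

# Illustrative values only, not from this commit: 8-bit operands,
# a 1-bit-per-slice input DAC, and 2-bit-per-slice memory cells.
in_b, w_b = 8, 8
max_input_bits_per_slice, max_weight_bits_per_slice = 1, 2
min_input_slices = min_weight_slices = 1

n_input_slices = max(ceil(in_b / max_input_bits_per_slice), min_input_slices)    # 8
n_weight_slices = max(ceil(w_b / max_weight_bits_per_slice), min_weight_slices)  # 4
n_sliced_psums = n_input_slices * n_weight_slices                                # 32

# Each sliced partial sum then carries a fraction of the encoded output bits:
encoded_output_bits = 16
average_output_bits_per_sliced_psum = encoded_output_bits / n_sliced_psums       # 0.5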

examples/arches/compute_in_memory/_include_functions.py

Lines changed: 20 additions & 12 deletions
@@ -2,23 +2,25 @@
 
 
 def get_array_fanout_reuse_input(spec: af.Spec) -> int:
-    n_rows = 1
+    """Get total fanout of array spatial dims that reuse input (= columns)."""
+    n = 1
     for leaf in spec.arch.get_nodes_of_type(af.arch.Leaf):
-        if "array_reuse_input" in leaf.spatial:
-            fanout = leaf.spatial["array_reuse_input"]["fanout"]
-            assert isinstance(fanout, (int, float)), f"fanout {leaf.name}.spatial.array_reuse_input.fanout is not a number"
-            n_rows *= fanout
-    return n_rows
+        for sp in leaf.spatial:
+            if sp.name.endswith("ARRAY_COLUMNS") or sp.name.endswith("ARRAY_ROWS"):
+                if str(sp.may_reuse) == "input" or str(sp.reuse) == "input":
+                    n *= sp.fanout
+    return n
 
 
 def get_array_fanout_reuse_output(spec: af.Spec) -> int:
-    n_cols = 1
+    """Get total fanout of array spatial dims that reuse output (= rows)."""
+    n = 1
     for leaf in spec.arch.get_nodes_of_type(af.arch.Leaf):
-        if "array_reuse_output" in leaf.spatial:
-            fanout = leaf.spatial["array_reuse_output"]["fanout"]
-            assert isinstance(fanout, (int, float)), f"fanout {leaf.name}.spatial.array_reuse_output.fanout is not a number"
-            n_cols *= fanout
-    return n_cols
+        for sp in leaf.spatial:
+            if sp.name.endswith("ARRAY_COLUMNS") or sp.name.endswith("ARRAY_ROWS"):
+                if str(sp.may_reuse) == "output" or str(sp.reuse) == "output":
+                    n *= sp.fanout
+    return n
 
 
 def get_array_fanout_total(spec: af.Spec) -> int:
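The rewrite scans named spatial dimensions (suffixed ARRAY_COLUMNS / ARRAY_ROWS) for a matching may_reuse or reuse tensor, instead of looking up array_reuse_input / array_reuse_output dict keys. A self-contained mock of that scan, using hypothetical stand-in classes since the real af.arch types live in the framework:

from dataclasses import dataclass

@dataclass
class SpatialDim:   # hypothetical stand-in for a spatial dimension entry
    name: str
    fanout: int
    may_reuse: str = ""
    reuse: str = ""

@dataclass
class Leaf:         # hypothetical stand-in for af.arch.Leaf
    spatial: list

def fanout_reuse(leaves, tensor: str) -> int:
    # Same scan as the rewritten helpers: multiply the fanouts of array
    # dimensions whose may_reuse or reuse matches the given tensor.
    n = 1
    for leaf in leaves:
        for sp in leaf.spatial:
            if sp.name.endswith("ARRAY_COLUMNS") or sp.name.endswith("ARRAY_ROWS"):
                if str(sp.may_reuse) == tensor or str(sp.reuse) == tensor:
                    n *= sp.fanout
    return n

# A 32x32 array as in the basic analog macro below: columns share inputs,
# rows share outputs.
leaves = [Leaf([SpatialDim("column_ARRAY_COLUMNS", 32, may_reuse="input")]),
          Leaf([SpatialDim("row_ARRAY_ROWS", 32, reuse="output")])]
print(fanout_reuse(leaves, "input"), fanout_reuse(leaves, "output"))  # 32 32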
@@ -33,10 +35,12 @@ def get_array_fanout_total(spec: af.Spec) -> int:
 from math import log2
 from typing import List, NamedTuple, Union
 
+
 class ProbableBits(NamedTuple):
     bits: list
     probability: float
 
+
 # ==============================================================================
 # Encoding functions
 # ==============================================================================
@@ -55,6 +59,7 @@ def magnitude_encode_hist(weights) -> List[ProbableBits]:
         encoded.append(ProbableBits(to_bits_unsigned(abs(normed), nbits)[1:], w))
     return norm_encoded_hist(encoded)
 
+
 def two_part_magnitude_encode_hist(weights):
     """
     Two (devices, timesteps, components, etc.) encode each signed value. If the
@@ -69,6 +74,7 @@ def two_part_magnitude_encode_hist(weights):
         m2.append(ProbableBits([0] * len(e.bits), e.probability / 2))
     return m2
 
+
 def offset_encode_hist(weights):
     """
     A signed value is encoded as the value minus the negative minimum value.
@@ -132,10 +138,12 @@ def zero_gated_xnor_encode_hist(weights):
     )
     return encoded
 
+
 # ==============================================================================
 # Helper functions
 # ==============================================================================
 
+
 def assert_hist_pow2_minus1(hist):
     x = 1
     while x <= len(hist):
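For intuition, offset_encode_hist's docstring describes offset encoding: shift every signed value by the most-negative value so all encoded values are non-negative. A minimal sketch of the idea on raw values (names here are illustrative, not this module's API):

def offset_encode(values):
    # Shift by the (negative) minimum so every encoded value is >= 0.
    offset = min(values)
    return [v - offset for v in values], offset

encoded, offset = offset_encode([-3, 0, 2])
print(encoded, offset)  # [0, 3, 5] -3
# Decoding adds the offset back (v = e + offset). Because encoded values are
# non-negative, in-array accumulation over them needs no signed arithmetic,
# matching signed_sum_across_inputs/weights: False in the macro below.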

examples/arches/compute_in_memory/_load_spec.py

Lines changed: 17 additions & 2 deletions
@@ -9,6 +9,7 @@ def get_spec(
     arch_name: str,
     compare_with_arch_name: str | None = None,
     add_dummy_main_memory: bool = False,
+    n_macros: int = 1,
 ) -> af.Spec:
     """
     Gets the spec for the given architecture. If `compare_with_arch_name` is given, the
@@ -22,7 +23,8 @@
     compare_with_arch_name: str | None
         The name of the architecture to compare with. If not given, variables will be
         taken from the given `arch_name`.
-
+    n_macros: int
+        The number of macros to use in the architecture.
     Returns
     -------
     spec: af.Spec
@@ -33,6 +35,7 @@
     else:
         compare_with_name = compare_with_arch_name
 
+    arch_name_base = arch_name
     arch_name = os.path.join(THIS_SCRIPT_DIR, f"{arch_name}.yaml")
     compare_with_name = os.path.join(THIS_SCRIPT_DIR, f"{compare_with_name}.yaml")
     variables = af.Variables.from_yaml(arch_name, top_key="variables")
@@ -43,15 +46,27 @@
     spec.config.expression_custom_functions.append(
         os.path.join(THIS_SCRIPT_DIR, "_include_functions.py")
     )
+    # Load architecture-specific helper functions if they exist
+    arch_helpers = os.path.join(
+        THIS_SCRIPT_DIR, f"{arch_name_base}_helper_functions.py"
+    )
+    if os.path.exists(arch_helpers):
+        spec.config.expression_custom_functions.append(arch_helpers)
     spec.config.component_models.append(
         os.path.join(THIS_SCRIPT_DIR, "components/*.py")
     )
+    if n_macros > 1:
+        macro = af.arch.Container(
+            name="MacroAuto",
+            spatial=[{"name": "macro", "fanout": n_macros, "power_gateable": True}],
+        )
+        spec.arch.nodes.insert(0, macro)
     if add_dummy_main_memory:
         main_memory = af.arch.Memory(
             name="MainMemory",
             component_class="Dummy",
             size=float("inf"),
-            tensors={"keep": "~weight"}
+            tensors={"keep": "~weight"},
         )
         spec.arch.nodes.insert(0, main_memory)
     return spec
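A hypothetical call site for the new parameter (the architecture name below is illustrative): with n_macros > 1, get_spec prepends a power-gateable MacroAuto container whose spatial fanout replicates the macro.

spec = get_spec(
    arch_name="my_arch",         # illustrative name; resolves to my_arch.yaml
    add_dummy_main_memory=True,  # prepend the infinite Dummy MainMemory
    n_macros=4,                  # prepend MacroAuto with spatial fanout 4
)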
Lines changed: 175 additions & 0 deletions
@@ -0,0 +1,175 @@
+# Basic analog CiM macro.
+# A simple analog compute-in-memory macro demonstrating the fundamental
+# components of an analog CiM array: row drivers (DAC), column drivers,
+# ADC, memory cells, and a virtualized MAC compute unit.
+
+{{include_text('_include.yaml')}}
+{{add_to_path('./memory_cells')}}
+
+arch:
+  variables:
+    <<: *variables_global
+
+    # =========================================================================
+    # Encoding-dependent parameters
+    # =========================================================================
+    encoded_input_bits: input_bits
+    encoded_weight_bits: weight_bits
+    encoded_output_bits: output_bits
+
+    input_encoding_func: offset_encode_hist
+    weight_encoding_func: offset_encode_hist
+
+    # For accuracy model. Can in-array accumulation include signed values?
+    # Signed accumulation not compatible with offset encoding (since offset
+    # encoding makes values non-negative).
+    signed_sum_across_inputs: False
+    signed_sum_across_weights: False
+
+    # =========================================================================
+    # Architecture & CiM Array Structure
+    # =========================================================================
+    cim_unit_width_cells: 1
+    cim_unit_depth_cells: 1
+    bits_per_cell: 8
+
+    # =========================================================================
+    # Data Converters
+    # =========================================================================
+    adc_resolution: 8
+    voltage_dac_resolution: 1
+    temporal_dac_resolution: 8
+
+    n_adc_per_bank: 2
+
+    # =========================================================================
+    # Hardware
+    # =========================================================================
+    cycle_period: 1e-7 * voltage_latency_scale
+    read_pulse_width: 1e-9
+
+  extra_attributes_for_all_component_models:
+    <<: *cim_component_attributes
+    tech_node: tech_node
+    cycle_period: cycle_period
+
+  nodes:
+  - !Toll # ADC: Column readout
+    name: ADC
+    tensors: {keep: output}
+    direction: up
+    bits_per_action: average_output_bits_per_sliced_psum
+    component_class: ADC
+    energy_scale: adc_energy_scale
+    area_scale: adc_area_scale
+    extra_attributes_for_component_model:
+      n_bits: adc_resolution
+      throughput_scale: 1
+      throughput: 1 / cycle_period * cols_active_at_once * throughput_scale
+
+  - !Toll # Column drivers precharge the array columns
+    name: ColumnDrivers
+    tensors: {keep: output}
+    direction: up
+    bits_per_action: average_output_bits_per_sliced_psum
+    component_class: ArrayColumnDrivers
+
+  - !Toll # Row drivers feed inputs onto the rows of the array
+    name: RowDrivers
+    tensors: {keep: input}
+    direction: down
+    bits_per_action: average_input_bits_per_slice
+    component_class: ArrayRowDrivers
+    extra_attributes_for_component_model:
+      temporal_spiking: true
+
+  # This memory catches sliding windows that may be sent spatially in the array. E.g.,
+  # convolution steps spatially unrolled onto columns with overlapping windows. Size =
+  # one input value per row. no_resend_to_below prevents reuse across temporal
+  # iterations.
+  - !Memory
+    name: DummyRowDriverMemory
+    component_class: Dummy
+    size: input.bits_per_value * array_parallel_inputs
+    tensors: {keep: input, no_resend_to_below: input}
+
+  - !Container # Each column stores a different weight slice. Columns share inputs.
+    name: Column
+    spatial:
+    - name: column_ARRAY_COLUMNS
+      fanout: 32
+      may_reuse: input
+      min_usage: 1
+      usage_scale: n_weight_slices
+
+  - !Container # Each row receives a different input slice. Rows share outputs.
+    name: Row
+    spatial:
+    - name: row_ARRAY_ROWS
+      fanout: 32
+      may_reuse: output
+      reuse: output
+      min_usage: 1
+
+  # CiM unit stores weights and computes MACs.
+  - !Memory
+    name: CimUnit
+    tensors: {keep: weight, no_refetch_from_above: weight, force_memory_hierarchy_order: False}
+    size: cim_unit_width_cells * cim_unit_depth_cells * bits_per_cell * n_weight_slices
+    bits_per_action: average_weight_bits_per_sliced_psum
+    n_parallel_instances: n_weight_slices
+    component_class: MemoryCell
+    actions: [{name: read, latency: cycle_period}]
+    extra_attributes_for_component_model:
+      n_instances: cim_unit_width_cells * cim_unit_depth_cells
+
+  # We account for compute energy in the CimUnit reads
+  - !Compute
+    name: FreeCompute
+    component_class: Dummy
+    enabled: len(All) == 3
+
+
+# These variables pertain to the workload, microarch, and circuits.
+variables:
+  inputs_hist: [0, 0, 0, 3, 2, 1, 0]
+  weights_hist: ([1] * 15)
+  outputs_hist: inputs_hist
+
+  ## Microarch ----------------------------------------------------------------
+  supported_input_bits: 8
+  supported_weight_bits: 8
+  supported_output_bits: 8
+  min_supported_input_bits: 1
+  min_supported_weight_bits: 1
+  min_supported_output_bits: 1
+
+  # Circuits ------------------------------------------------------------------
+  voltage: 1
+  tech_node: 65e-9 # 65nm
+  cell_config: "{{find_path('rram_example.yaml')}}"
+  voltage_energy_scale: voltage ** 2
+  voltage_latency_scale: 1 / voltage
+
+  # Calibration ---------------------------------------------------------------
+  adc_energy_scale: voltage_energy_scale
+  adc_area_scale: 1
+  row_col_drivers_area_scale: 1
+
+
+# This workload is sized to get peak throughput & energy efficiency.
+# 32 columns × 32 rows fills the array.
+workload:
+  rank_sizes:
+    M: 1
+    N: 32
+    K: 32
+
+  einsums:
+  - name: Matmul
+    tensor_accesses:
+    - {name: input, projection: [m, k], bits_per_value: 8}
+    - {name: weight, projection: [k, n], bits_per_value: 8}
+    - {name: output, projection: [m, n], output: True, bits_per_value: 8}
+
+renames: {} # Not needed for this workload
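A back-of-envelope check (a sketch; the model's own accounting may differ) that this workload exactly fills the 32×32 array and reaches peak MAC throughput at the stated cycle period:

rows, cols = 32, 32          # row_ARRAY_ROWS and column_ARRAY_COLUMNS fanouts
M, N, K = 1, 32, 32          # workload rank_sizes
macs = M * N * K             # 1024 MACs in the matmul
assert macs == rows * cols   # one MAC per array position: the array is full

cycle_period = 1e-7          # seconds (voltage = 1, so voltage_latency_scale = 1)
peak_macs_per_second = rows * cols / cycle_period  # ~1.0e10 MAC/s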
