Skip to content

Commit 0f86c88

Browse files
authored
Fix ACT layer gradient computation on CUDA (#3128)
Move effective_weights accumulation into update_act_state() and finalize_act_output() kernels so that the weights used in backward() match the actual forward pass computation. Previously, true_effective_weights_ was computed on the host using remainders_/cumulative_halting_ values that became stale after CUDA kernels updated them on the device. This caused gradient mismatches in test_layer() on GPU builds.
1 parent 49b7cba commit 0f86c88

7 files changed

Lines changed: 34 additions & 35 deletions

File tree

dlib/cuda/cpu_dlib.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3267,6 +3267,7 @@ namespace dlib
32673267
resizable_tensor& cumulative_halting,
32683268
resizable_tensor& remainders,
32693269
resizable_tensor& n_steps,
3270+
resizable_tensor& effective_weights,
32703271
long batch_size,
32713272
long seq_len,
32723273
long d_model,
@@ -3281,6 +3282,7 @@ namespace dlib
32813282
float* cum_halt = cumulative_halting.host();
32823283
float* remain = remainders.host();
32833284
float* steps = n_steps.host();
3285+
float* eff_weights = effective_weights.host();
32843286

32853287
for (long pos = 0; pos < batch_size * seq_len; ++pos) {
32863288
if (cum_halt[pos] < halt_threshold) {
@@ -3294,6 +3296,7 @@ namespace dlib
32943296
cum_halt[pos] += effective;
32953297
remain[pos] -= effective;
32963298
steps[pos] = static_cast<float>(current_step + 1);
3299+
eff_weights[pos] += effective;
32973300

32983301
for (long c = 0; c < num_channels; ++c) {
32993302
for (long d = 0; d < d_model; ++d) {
@@ -3309,6 +3312,7 @@ namespace dlib
33093312
resizable_tensor& output,
33103313
const tensor& input_data,
33113314
const tensor& remainders,
3315+
resizable_tensor& effective_weights,
33123316
long batch_size,
33133317
long seq_len,
33143318
long d_model,
@@ -3318,13 +3322,16 @@ namespace dlib
33183322
const float* in_ptr = input_data.host();
33193323
const float* remain = remainders.host();
33203324
float* out_ptr = output.host();
3325+
float* eff_weights = effective_weights.host();
33213326

33223327
for (long pos = 0; pos < batch_size * seq_len; ++pos) {
33233328
float r = remain[pos];
33243329
if (r > 1e-6f) {
33253330
const long n = pos / seq_len;
33263331
const long s = pos % seq_len;
33273332

3333+
eff_weights[pos] += r;
3334+
33283335
for (long c = 0; c < num_channels; ++c) {
33293336
for (long d = 0; d < d_model; ++d) {
33303337
const long idx = ((n * num_channels + c) * seq_len + s) * d_model + d;

dlib/cuda/cpu_dlib.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -555,6 +555,7 @@ namespace dlib
555555
resizable_tensor& cumulative_halting,
556556
resizable_tensor& remainders,
557557
resizable_tensor& n_steps,
558+
resizable_tensor& effective_weights,
558559
long batch_size,
559560
long seq_len,
560561
long d_model,
@@ -567,6 +568,7 @@ namespace dlib
567568
resizable_tensor& output,
568569
const tensor& input_data,
569570
const tensor& remainders,
571+
resizable_tensor& effective_weights,
570572
long batch_size,
571573
long seq_len,
572574
long d_model,

dlib/cuda/cuda_dlib.cu

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2819,6 +2819,7 @@ namespace dlib
28192819
float* cumulative_halting,
28202820
float* remainders,
28212821
float* n_steps,
2822+
float* effective_weights,
28222823
size_t batch_size,
28232824
size_t seq_len,
28242825
size_t d_model,
@@ -2841,6 +2842,7 @@ namespace dlib
28412842
cumulative_halting[pos] += effective;
28422843
remainders[pos] -= effective;
28432844
n_steps[pos] = static_cast<float>(current_step + 1);
2845+
effective_weights[pos] += effective;
28442846

28452847
for (size_t c = 0; c < num_channels; ++c) {
28462848
for (size_t d = 0; d < d_model; ++d) {
@@ -2859,6 +2861,7 @@ namespace dlib
28592861
resizable_tensor& cumulative_halting,
28602862
resizable_tensor& remainders,
28612863
resizable_tensor& n_steps,
2864+
resizable_tensor& effective_weights,
28622865
long batch_size,
28632866
long seq_len,
28642867
long d_model,
@@ -2877,6 +2880,7 @@ namespace dlib
28772880
cumulative_halting.device(),
28782881
remainders.device(),
28792882
n_steps.device(),
2883+
effective_weights.device(),
28802884
batch_size,
28812885
seq_len,
28822886
d_model,
@@ -2889,6 +2893,7 @@ namespace dlib
28892893
float* output,
28902894
const float* input_data,
28912895
const float* remainders,
2896+
float* effective_weights,
28922897
size_t batch_size,
28932898
size_t seq_len,
28942899
size_t d_model,
@@ -2902,6 +2907,8 @@ namespace dlib
29022907
const size_t n = pos / seq_len;
29032908
const size_t s = pos % seq_len;
29042909

2910+
effective_weights[pos] += r;
2911+
29052912
for (size_t c = 0; c < num_channels; ++c) {
29062913
for (size_t d = 0; d < d_model; ++d) {
29072914
const size_t idx = ((n * num_channels + c) * seq_len + s) * d_model + d;
@@ -2916,6 +2923,7 @@ namespace dlib
29162923
resizable_tensor& output,
29172924
const tensor& input_data,
29182925
const tensor& remainders,
2926+
resizable_tensor& effective_weights,
29192927
long batch_size,
29202928
long seq_len,
29212929
long d_model,
@@ -2929,6 +2937,7 @@ namespace dlib
29292937
output.device(),
29302938
input_data.device(),
29312939
remainders.device(),
2940+
effective_weights.device(),
29322941
batch_size,
29332942
seq_len,
29342943
d_model,

dlib/cuda/cuda_dlib.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -627,6 +627,7 @@ namespace dlib
627627
resizable_tensor& cumulative_halting,
628628
resizable_tensor& remainders,
629629
resizable_tensor& n_steps,
630+
resizable_tensor& effective_weights,
630631
long batch_size,
631632
long seq_len,
632633
long d_model,
@@ -639,6 +640,7 @@ namespace dlib
639640
resizable_tensor& output,
640641
const tensor& input_data,
641642
const tensor& remainders,
643+
resizable_tensor& effective_weights,
642644
long batch_size,
643645
long seq_len,
644646
long d_model,

dlib/cuda/tensor_tools.cpp

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1440,6 +1440,7 @@ namespace dlib { namespace tt
14401440
resizable_tensor& cumulative_halting,
14411441
resizable_tensor& remainders,
14421442
resizable_tensor& n_steps,
1443+
resizable_tensor& effective_weights,
14431444
long batch_size,
14441445
long seq_len,
14451446
long d_model,
@@ -1450,28 +1451,29 @@ namespace dlib { namespace tt
14501451
{
14511452
#ifdef DLIB_USE_CUDA
14521453
cuda::update_act_state(output, input_data, halt_probs, cumulative_halting, remainders,
1453-
n_steps, batch_size, seq_len, d_model, num_channels, halt_threshold, current_step);
1454+
n_steps, effective_weights, batch_size, seq_len, d_model, num_channels, halt_threshold, current_step);
14541455
#else
14551456
cpu::update_act_state(output, input_data, halt_probs, cumulative_halting, remainders,
1456-
n_steps, batch_size, seq_len, d_model, num_channels, halt_threshold, current_step);
1457+
n_steps, effective_weights, batch_size, seq_len, d_model, num_channels, halt_threshold, current_step);
14571458
#endif
14581459
}
14591460

14601461
void finalize_act_output(
14611462
resizable_tensor& output,
14621463
const tensor& input_data,
14631464
const tensor& remainders,
1465+
resizable_tensor& effective_weights,
14641466
long batch_size,
14651467
long seq_len,
14661468
long d_model,
14671469
long num_channels
14681470
)
14691471
{
14701472
#ifdef DLIB_USE_CUDA
1471-
cuda::finalize_act_output(output, input_data, remainders,
1473+
cuda::finalize_act_output(output, input_data, remainders, effective_weights,
14721474
batch_size, seq_len, d_model, num_channels);
14731475
#else
1474-
cpu::finalize_act_output(output, input_data, remainders,
1476+
cpu::finalize_act_output(output, input_data, remainders, effective_weights,
14751477
batch_size, seq_len, d_model, num_channels);
14761478
#endif
14771479
}

dlib/cuda/tensor_tools.h

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2428,6 +2428,7 @@ namespace dlib { namespace tt
24282428
resizable_tensor& cumulative_halting,
24292429
resizable_tensor& remainders,
24302430
resizable_tensor& n_steps,
2431+
resizable_tensor& effective_weights,
24312432
long batch_size,
24322433
long seq_len,
24332434
long d_model,
@@ -2445,12 +2446,12 @@ namespace dlib { namespace tt
24452446
- input_data.nc() == d_model
24462447
- output has the same dimensions as input_data
24472448
- halt_probs.size() == batch_size * seq_len
2448-
- cumulative_halting.size() == remainders.size() == n_steps.size() == batch_size * seq_len
2449+
- cumulative_halting.size() == remainders.size() == n_steps.size() == effective_weights.size() == batch_size * seq_len
24492450
ensures
24502451
- Core ACT update step that accumulates weighted outputs:
24512452
- Updates ACT state for all positions
24522453
- Accumulates weighted outputs: output += α_t^n * input_data
2453-
- Updates cumulative_halting, remainders, and n_steps
2454+
- Updates cumulative_halting, remainders, n_steps, and effective_weights
24542455
- batch_size: number of samples in the batch
24552456
- seq_len: sequence length (number of positions to process)
24562457
- d_model: model dimension per channel
@@ -2463,6 +2464,7 @@ namespace dlib { namespace tt
24632464
resizable_tensor& output,
24642465
const tensor& input_data,
24652466
const tensor& remainders,
2467+
resizable_tensor& effective_weights,
24662468
long batch_size,
24672469
long seq_len,
24682470
long d_model,
@@ -2475,10 +2477,11 @@ namespace dlib { namespace tt
24752477
- input_data.nr() == seq_len
24762478
- input_data.nc() == d_model
24772479
- output has the same dimensions as input_data
2478-
- remainders.size() == batch_size * seq_len
2480+
- remainders.size() == effective_weights.size() == batch_size * seq_len
24792481
ensures
24802482
- Finalizes ACT output by adding remainder contributions:
24812483
- Adds final remainder contributions: output += ρ_t * input_data
2484+
- Updates effective_weights with remainder values
24822485
- Applied only to positions with significant remainder (> 1e-6)
24832486
- batch_size: number of samples in the batch
24842487
- seq_len: sequence length (number of positions to process)

dlib/dnn/layers.h

Lines changed: 2 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -5857,29 +5857,11 @@ namespace dlib
58575857
halting_probs_, logits_, input, params,
58585858
batch_size_, seq_len_, feature_dim_);
58595859

5860-
// Capture effective weights before state update
5861-
const float* p_halt = halting_probs_.host();
5862-
const float* cum_halt = cum_halt_ptr;
5863-
const float* remainders = remainders_ptr;
5864-
float* true_weights = true_effective_weights_.host();
5865-
5866-
for (long pos = 0; pos < total_positions; ++pos) {
5867-
if (cum_halt[pos] < halt_threshold_) {
5868-
float p = p_halt[pos];
5869-
float r = remainders[pos];
5870-
5871-
// Compute effective weight: alpha_t^n = min(p * rho, theta - h_t^(n-1))
5872-
float effective = std::min(p * r, halt_threshold_ - cum_halt[pos]);
5873-
5874-
// Store for backward pass
5875-
true_weights[pos] += effective;
5876-
}
5877-
}
5878-
58795860
// Update ACT state and accumulate weighted outputs
58805861
tt::update_act_state(
58815862
output, input, halting_probs_,
58825863
cumulative_halting_, remainders_, n_steps_,
5864+
true_effective_weights_,
58835865
batch_size_, seq_len_, d_model_, num_channels_,
58845866
halt_threshold_, step
58855867
);
@@ -5891,17 +5873,9 @@ namespace dlib
58915873
// Finalize with remainder contributions
58925874
tt::finalize_act_output(
58935875
output, input, remainders_,
5876+
true_effective_weights_,
58945877
batch_size_, seq_len_, d_model_, num_channels_);
58955878

5896-
// Add remainder weights for gradient computation
5897-
const float* final_remainders = remainders_.host();
5898-
float* true_weights = true_effective_weights_.host();
5899-
for (long pos = 0; pos < total_positions; ++pos) {
5900-
if (final_remainders[pos] > 1e-6f) {
5901-
true_weights[pos] += final_remainders[pos];
5902-
}
5903-
}
5904-
59055879
// Compute statistics for monitoring and regularization
59065880
compute_ponder_stats();
59075881
}

0 commit comments

Comments
 (0)