6 changes: 3 additions & 3 deletions NAM/get_dsp.cpp
@@ -42,7 +42,7 @@ class CoreVersionSupportChecker : public IVersionSupportChecker
std::vector<std::shared_ptr<const IVersionSupportChecker>>& version_support_registry()
{
static std::vector<std::shared_ptr<const IVersionSupportChecker>> registry{
    std::make_shared<CoreVersionSupportChecker>()};
return registry;
}

@@ -123,8 +123,8 @@ void verify_config_version(const std::string versionStr)
if (support == Supported::PARTIAL)
{
std::stringstream ss;
std::cerr << "Model config is a partially-supported version " << versionStr
<< ". Continuing with partial support." << std::endl;
std::cerr << "Model config is a partially-supported version " << versionStr << ". Continuing with partial support."
<< std::endl;
}
}

147 changes: 67 additions & 80 deletions NAM/wavenet/a2_fast.cpp
@@ -1,31 +1,31 @@
#if defined(NAM_ENABLE_A2_FAST)

// Ring-buffer strategy:
// 0 = linear memmove-rewind (variable worst-case latency, sporadic spikes)
// 1 = pow2 + tail mirror (constant per-block work, branchless reads)
// Controlled externally with -DNAM_A2_RING_MODE=0 for head-to-head comparison.
#ifndef NAM_A2_RING_MODE
#define NAM_A2_RING_MODE 1
#endif

#include "a2_fast.h"

#include <algorithm>
#include <array>
#include <cmath>
#include <cstddef>
#include <cstring>
#include <iterator>
#include <memory>
#include <sstream>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

#include <Eigen/Dense>

#include "../dsp.h"

namespace nam
{
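
The ring-mode comment at the top of this file is the heart of the rewrite, so a compact illustration may help review. Below is a minimal single-channel sketch of mode 1 (pow2 + tail mirror); MirrorRing and its members are illustrative names, not the multi-channel code in this file, which interleaves Channels floats per column.

#include <algorithm>
#include <cstring>
#include <vector>

// Single-channel sketch of ring mode 1 (pow2 + tail mirror).
struct MirrorRing
{
  std::vector<float> buf; // pow2_size + max_block floats; tail mirrors the head
  int pow2_size = 0, mask = 0, write_pos = 0, max_block = 0;

  void init(int lookback, int maxBlock)
  {
    pow2_size = 1;
    while (pow2_size < lookback + maxBlock) // next_pow2
      pow2_size <<= 1;
    mask = pow2_size - 1;
    max_block = maxBlock;
    buf.assign(static_cast<size_t>(pow2_size) + maxBlock, 0.0f);
    write_pos = lookback;
  }

  void write(const float* src, int n) // n <= max_block
  {
    const int first = std::min(n, pow2_size - write_pos);
    std::memcpy(buf.data() + write_pos, src, static_cast<size_t>(first) * sizeof(float));
    if (first < n) // wrapped: remainder goes to the front
      std::memcpy(buf.data(), src + first, static_cast<size_t>(n - first) * sizeof(float));
    // Refresh the tail mirror: cols [pow2_size, pow2_size + max_block) copy
    // cols [0, max_block), so any window of <= max_block samples starting at a
    // masked position is contiguous in memory.
    std::memcpy(buf.data() + pow2_size, buf.data(), static_cast<size_t>(max_block) * sizeof(float));
    write_pos = (write_pos + n) & mask;
  }

  // Branchless read: contiguous pointer to n samples ending `delay` samples back.
  const float* read(int n, int delay) const { return buf.data() + ((write_pos - n - delay) & mask); }
};

Per block the work is fixed: the block copy itself plus one max_block-sample mirror copy. That is the constant-work, branchless-read property the strategy comment claims for mode 1.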
@@ -92,20 +92,20 @@ class A2FastModel : public DSP

// Conv1D input history ring buffer, column-major (Channels rows).
std::vector<float> history;
#if NAM_A2_RING_MODE == 1
// pow2 ring + tail mirror. Storage = (pow2_size + max_buffer_size) cols.
// write_pos is kept in [0, pow2_size), reads use (pos & pow2_mask) and are
// always contiguous because cols [pow2_size, pow2_size + max_buffer_size)
// mirror cols [0, max_buffer_size).
int pow2_size = 0;
int pow2_mask = 0;
int write_pos = 0;
#else
// Linear ring with sporadic memmove-rewind. history_cols = 2*max_lookback +
// max_buffer_size; write_pos grows monotonically until rewind fires.
int history_cols = 0;
int write_pos = 0;
#endif
};

std::array<Layer, kNumLayers> _layers;
@@ -124,20 +124,20 @@ class A2FastModel : public DSP

// Head ring buffer (Channels rows, col-major). Same ring layout as per-layer.
std::vector<float> _head_history;
#if NAM_A2_RING_MODE == 1
int _head_pow2_size = 0;
int _head_pow2_mask = 0;
int _head_write_pos = 0;
#else
int _head_history_cols = 0;
int _head_write_pos = 0;
#endif

// Working buffers (all Channels rows, max_buffer_size cols, col-major).
std::vector<float> _layer_in; // current layer input / next layer input (in-place residual)
std::vector<float> _head_sum; // accumulates activations across all layers
std::vector<float> _z; // per-layer conv output accumulator (tap-major)
std::vector<float> _cond; // float32 copy of the double NAM_SAMPLE input, reused each block
std::vector<float> _head_out; // float32 head output before writing to NAM_SAMPLE

int _prewarm_samples = 0;
@@ -300,29 +300,29 @@ void A2FastModel<Channels>::SetMaxBufferSize(int maxBufferSize)

for (auto& L : _layers)
{
#if NAM_A2_RING_MODE == 1
L.pow2_size = next_pow2(L.max_lookback + maxBufferSize);
L.pow2_mask = L.pow2_size - 1;
L.history.assign(static_cast<size_t>(Channels) * (L.pow2_size + maxBufferSize), 0.0f);
L.write_pos = L.max_lookback;
#else
L.history_cols = 2 * L.max_lookback + maxBufferSize;
L.history.assign(static_cast<size_t>(Channels) * L.history_cols, 0.0f);
L.write_pos = L.max_lookback;
#endif
}

const int head_lookback = kHeadKernelSize - 1;
#if NAM_A2_RING_MODE == 1
_head_pow2_size = next_pow2(head_lookback + maxBufferSize);
_head_pow2_mask = _head_pow2_size - 1;
_head_history.assign(static_cast<size_t>(Channels) * (_head_pow2_size + maxBufferSize), 0.0f);
_head_write_pos = head_lookback;
#else
_head_history_cols = 2 * head_lookback + maxBufferSize;
_head_history.assign(static_cast<size_t>(Channels) * _head_history_cols, 0.0f);
_head_write_pos = head_lookback;
#endif
}

// -----------------------------------------------------------------------------
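
The storage math of the two modes lands in nearly the same place. A worked example under hypothetical figures (the real max_lookback depends on the layer dilations, which sit in collapsed context, and next_pow2 is likewise defined above this hunk; its assumed shape is sketched here):

constexpr int next_pow2_c(int n) // assumed shape of next_pow2
{
  int p = 1;
  while (p < n)
    p <<= 1;
  return p;
}
constexpr int kLookback = 1008, kBlock = 512, kChannels = 8; // hypothetical figures
constexpr int kMode1Cols = next_pow2_c(kLookback + kBlock) + kBlock; // 2048 + 512 = 2560
constexpr int kMode0Cols = 2 * kLookback + kBlock;                   // 2016 + 512 = 2528
static_assert(kMode1Cols * kChannels * sizeof(float) == 81920, "mode 1: 80 KiB per layer");
static_assert(kMode0Cols * kChannels * sizeof(float) == 80896, "mode 0: ~79 KiB per layer");

So the pow2 rounding costs almost nothing in memory; the real trade is the per-block mirror memcpy in mode 1 against the sporadic memmove spike in mode 0.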
@@ -336,23 +336,22 @@ void A2FastModel<Channels>::SetMaxBufferSize(int maxBufferSize)
template <int Channels>
void A2FastModel<Channels>::_ring_write(Layer& L, int num_frames)
{
#if NAM_A2_RING_MODE == 1
const int mbs = GetMaxBufferSize();
float* const hist = L.history.data();
const float* const src = _layer_in.data();
const int wp = L.write_pos;
const int first = std::min(num_frames, L.pow2_size - wp);
-  std::memcpy(hist + static_cast<size_t>(wp) * Channels, src,
-              static_cast<size_t>(first) * Channels * sizeof(float));
+  std::memcpy(hist + static_cast<size_t>(wp) * Channels, src, static_cast<size_t>(first) * Channels * sizeof(float));
if (first < num_frames)
{
std::memcpy(hist, src + static_cast<size_t>(first) * Channels,
static_cast<size_t>(num_frames - first) * Channels * sizeof(float));
}
-  std::memcpy(hist + static_cast<size_t>(L.pow2_size) * Channels, hist,
-              static_cast<size_t>(mbs) * Channels * sizeof(float));
+  std::memcpy(
+    hist + static_cast<size_t>(L.pow2_size) * Channels, hist, static_cast<size_t>(mbs) * Channels * sizeof(float));
L.write_pos = (wp + num_frames) & L.pow2_mask;
#else
if (L.write_pos + num_frames > L.history_cols)
{
const int keep = L.max_lookback;
@@ -363,29 +362,28 @@ void A2FastModel<Channels>::_ring_write(Layer& L, int num_frames)
std::memcpy(L.history.data() + static_cast<size_t>(L.write_pos) * Channels, _layer_in.data(),
static_cast<size_t>(num_frames) * Channels * sizeof(float));
L.write_pos += num_frames;
#endif
}

template <int Channels>
void A2FastModel<Channels>::_head_ring_write(int num_frames)
{
#if NAM_A2_RING_MODE == 1
const int mbs = GetMaxBufferSize();
float* const hist = _head_history.data();
const float* const src = _head_sum.data();
const int wp = _head_write_pos;
const int first = std::min(num_frames, _head_pow2_size - wp);
-  std::memcpy(hist + static_cast<size_t>(wp) * Channels, src,
-              static_cast<size_t>(first) * Channels * sizeof(float));
+  std::memcpy(hist + static_cast<size_t>(wp) * Channels, src, static_cast<size_t>(first) * Channels * sizeof(float));
if (first < num_frames)
{
std::memcpy(hist, src + static_cast<size_t>(first) * Channels,
static_cast<size_t>(num_frames - first) * Channels * sizeof(float));
}
-  std::memcpy(hist + static_cast<size_t>(_head_pow2_size) * Channels, hist,
-              static_cast<size_t>(mbs) * Channels * sizeof(float));
+  std::memcpy(
+    hist + static_cast<size_t>(_head_pow2_size) * Channels, hist, static_cast<size_t>(mbs) * Channels * sizeof(float));
_head_write_pos = (wp + num_frames) & _head_pow2_mask;
#else
const int keep = kHeadKernelSize - 1;
if (_head_write_pos + num_frames > _head_history_cols)
{
@@ -396,7 +394,7 @@ void A2FastModel<Channels>::_head_ring_write(int num_frames)
std::memcpy(_head_history.data() + static_cast<size_t>(_head_write_pos) * Channels, _head_sum.data(),
static_cast<size_t>(num_frames) * Channels * sizeof(float));
_head_write_pos += num_frames;
#endif
}

// -----------------------------------------------------------------------------
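
For completeness, the companion single-channel sketch for mode 0 (linear memmove-rewind), again with illustrative names:

#include <cstring>
#include <vector>

// Single-channel sketch of ring mode 0. Writes advance linearly; when a block
// would overrun the buffer, the lookback window is moved back to the front.
struct LinearRing
{
  std::vector<float> buf;
  int cols = 0, lookback = 0, write_pos = 0;

  void init(int lb, int maxBlock)
  {
    lookback = lb;
    cols = 2 * lb + maxBlock;
    buf.assign(cols, 0.0f);
    write_pos = lb;
  }

  void write(const float* src, int n) // n <= maxBlock
  {
    if (write_pos + n > cols)
    {
      // The sporadic spike: rewind by copying the lookback window to the front.
      std::memmove(buf.data(), buf.data() + (write_pos - lookback), static_cast<size_t>(lookback) * sizeof(float));
      write_pos = lookback;
    }
    std::memcpy(buf.data() + write_pos, src, static_cast<size_t>(n) * sizeof(float));
    write_pos += n;
  }

  // Plain pointer arithmetic; no masking needed because write_pos is monotonic
  // between rewinds and never closer than `lookback` to the start.
  const float* read(int n, int delay) const { return buf.data() + (write_pos - n - delay); }
};

The 2 * lookback sizing is what spaces the rewinds out: after a rewind there are lookback + maxBlock free columns, so the memmove fires at most once per ~lookback samples rather than once per block.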
@@ -418,15 +416,13 @@ void A2FastModel<Channels>::_layer_forward_k(Layer& L, const float* cond, int num_frames)
// D` samples into the past. In pow2 mode the position is wrapped by mask and
// reads spanning the wrap land in the tail mirror; in linear mode write_pos
// is monotonic and arithmetic is plain.
#if NAM_A2_RING_MODE == 1
const int mask = L.pow2_mask;
-  auto tap_base_phys = [&](int taps_back) {
-    return (L.write_pos - num_frames - taps_back * D) & mask;
-  };
+  auto tap_base_phys = [&](int taps_back) { return (L.write_pos - num_frames - taps_back * D) & mask; };
#else
const int base = L.write_pos - num_frames;
auto tap_base_phys = [&](int taps_back) { return base - taps_back * D; };
#endif

// Two conv strategies, dispatched at compile time on Channels:
//
@@ -598,10 +594,10 @@

// Post-conv: bias, mixin, LeakyReLU, head_sum, 1x1 residual — all block ops.
ztile.colwise() += conv_b_vec;
ztile.noalias() += mixin_vec * cond_row; // rank-1 outer product
ztile = (ztile.array() < 0.0f).select(ztile.array() * kLeakySlope, ztile.array());
hsum_block += ztile;
lin_block.noalias() += l1x1_mat * ztile; // 8x8 × 8xN GEMM
lin_block.colwise() += l1x1_b_vec;
}
}
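
The post-conv epilogue above is pure Eigen block arithmetic. A standalone sketch of the same idioms follows; the shapes and the leaky slope are hypothetical, and the real ztile/hsum_block/lin_block are Maps into the working buffers rather than owning matrices:

#include <Eigen/Dense>

int main()
{
  const float kLeakySlope = 0.01f; // assumed value; the real constant lives elsewhere in this file
  const int frames = 4;            // hypothetical tile width

  Eigen::MatrixXf z = Eigen::MatrixXf::Random(8, frames);       // conv output tile
  Eigen::VectorXf bias = Eigen::VectorXf::Constant(8, 0.1f);    // conv bias
  Eigen::VectorXf mixin = Eigen::VectorXf::Random(8);           // conditioning mixin weights
  Eigen::RowVectorXf cond = Eigen::RowVectorXf::Random(frames); // conditioning row

  z.colwise() += bias;         // bias broadcast over all frames
  z.noalias() += mixin * cond; // rank-1 outer product: (8x1)*(1xN) = 8xN
  // Branchless LeakyReLU: where z < 0 take kLeakySlope * z, else z.
  z = (z.array() < 0.0f).select(kLeakySlope * z, z);

  Eigen::MatrixXf l1x1 = Eigen::MatrixXf::Random(8, 8);
  Eigen::MatrixXf residual = l1x1 * z; // the 8x8 × 8xN GEMM on the residual path
  return residual.size() > 0 ? 0 : 1;
}

Every step is a whole-tile operation, so Eigen can vectorize across frames instead of branching per sample.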
@@ -616,15 +612,9 @@ void A2FastModel<Channels>::_layer_forward(int layer_idx, const float* cond, int num_frames)
_ring_write(L, num_frames);
switch (L.kernel_size)
{
-    case 6:
-      _layer_forward_k<6>(L, cond, num_frames);
-      break;
-    case 15:
-      _layer_forward_k<15>(L, cond, num_frames);
-      break;
-    default:
-      throw std::runtime_error("A2FastModel: unexpected kernel_size "
-                               + std::to_string(L.kernel_size));
+    case 6: _layer_forward_k<6>(L, cond, num_frames); break;
+    case 15: _layer_forward_k<15>(L, cond, num_frames); break;
+    default: throw std::runtime_error("A2FastModel: unexpected kernel_size " + std::to_string(L.kernel_size));
}
}

@@ -635,15 +625,13 @@ template <int Channels>
void A2FastModel<Channels>::_head_forward(float* output, int num_frames)
{
_head_ring_write(num_frames);
#if NAM_A2_RING_MODE == 1
const int mask = _head_pow2_mask;
-  auto col_of = [&](int f, int k) {
-    return (_head_write_pos - num_frames + f - (kHeadKernelSize - 1 - k)) & mask;
-  };
+  auto col_of = [&](int f, int k) { return (_head_write_pos - num_frames + f - (kHeadKernelSize - 1 - k)) & mask; };
#else
const int base = _head_write_pos - num_frames;
auto col_of = [&](int f, int k) { return base + f - (kHeadKernelSize - 1 - k); };
#endif

for (int f = 0; f < num_frames; f++)
{
@@ -874,9 +862,8 @@ bool is_a2_shape(const nlohmann::json& config, int* channels)
return false;

// No FiLM anywhere
-  for (const char* key :
-       {"conv_pre_film", "conv_post_film", "input_mixin_pre_film", "input_mixin_post_film", "activation_pre_film",
-        "activation_post_film", "layer1x1_post_film", "head1x1_post_film"})
+  for (const char* key : {"conv_pre_film", "conv_post_film", "input_mixin_pre_film", "input_mixin_post_film",
+                          "activation_pre_film", "activation_post_film", "layer1x1_post_film", "head1x1_post_film"})
{
if (!film_inactive(la, key))
return false;
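
film_inactive is referenced here but defined in the collapsed region above. A plausible sketch of such a predicate, assuming a FiLM entry counts as inactive when its key is absent or null; the actual config schema is an assumption, not taken from this diff:

#include "json.hpp"

// Hypothetical only; the real film_inactive may use a different schema.
static bool film_inactive(const nlohmann::json& layer, const char* key)
{
  if (!layer.contains(key))
    return true; // no FiLM block configured under this key
  return layer.at(key).is_null(); // present but explicitly disabled
}

Whatever its exact shape, the intent of the loop above is clear: the fast path is only taken when every FiLM hook is a no-op.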
8 changes: 4 additions & 4 deletions NAM/wavenet/a2_fast.h
@@ -14,11 +14,11 @@

#if defined(NAM_ENABLE_A2_FAST)

#include <array>
#include <memory>

#include "../model_config.h"
#include "json.hpp"
#include "../model_config.h"
#include "json.hpp"

namespace nam
{
2 changes: 1 addition & 1 deletion NAM/wavenet/model.cpp
@@ -13,7 +13,7 @@
#include "model.h"

#if defined(NAM_ENABLE_A2_FAST)
#include "a2_fast.h"
#include "a2_fast.h"
#endif

// detail::Head (WaveNet post-stack head) =====================================
3 changes: 1 addition & 2 deletions NAM/wavenet/slimmable.cpp
@@ -309,8 +309,7 @@ void SlimmableWavenet::_pending_store_release(std::shared_ptr<StagedSlimModel> p

std::shared_ptr<SlimmableWavenet::StagedSlimModel> SlimmableWavenet::_pending_exchange_take_acq_rel()
{
-  return std::atomic_exchange_explicit(
-    &_pending_staged, std::shared_ptr<StagedSlimModel>{}, std::memory_order_acq_rel);
+  return std::atomic_exchange_explicit(&_pending_staged, std::shared_ptr<StagedSlimModel>{}, std::memory_order_acq_rel);
}
#else
void SlimmableWavenet::_pending_clear_release()
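
The acquire-release exchange above is the consumer half of a staging handoff built on the shared_ptr atomic free functions (kept, per the slimmable.h note below, because libc++ has no usable std::atomic<std::shared_ptr<T>>). A minimal sketch of the whole pattern, with illustrative names:

#include <atomic>
#include <memory>

struct StagedModel
{
  // weights, ring buffers, etc.
};

std::shared_ptr<StagedModel> g_pending; // touched only via atomic_* free functions

// Non-realtime thread: publish a freshly built model with release semantics.
void publish(std::shared_ptr<StagedModel> fresh)
{
  std::atomic_store_explicit(&g_pending, std::move(fresh), std::memory_order_release);
}

// Audio thread: take ownership in a single exchange, leaving an empty
// shared_ptr behind; acq_rel pairs with the release store above.
std::shared_ptr<StagedModel> take()
{
  return std::atomic_exchange_explicit(&g_pending, std::shared_ptr<StagedModel>{}, std::memory_order_acq_rel);
}

Note these deprecated free functions may lock internally, so lock-freedom is not guaranteed; the pattern trades that guarantee for portability across standard libraries while keeping the audio-thread side to a single call.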
2 changes: 1 addition & 1 deletion NAM/wavenet/slimmable.h
@@ -6,7 +6,7 @@
#ifdef _LIBCPP_VERSION
// libc++: std::atomic<std::shared_ptr<T>> is not viable; staging uses deprecated atomic_* free functions.
#else
#include <atomic>
#endif

#include "../dsp.h"