|
| 1 | +// Copyright 2025-present the zvec project |
| 2 | +// |
| 3 | +// Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +// you may not use this file except in compliance with the License. |
| 5 | +// You may obtain a copy of the License at |
| 6 | +// |
| 7 | +// http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +// |
| 9 | +// Unless required by applicable law or agreed to in writing, software |
| 10 | +// distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +// See the License for the specific language governing permissions and |
| 13 | +// limitations under the License. |
| 14 | + |
| 15 | +#pragma once |
| 16 | + |
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <memory>
#include <string>
#include <unordered_set>
#include <vector>
#include <rocksdb/db.h>
#include <zvec/db/common/rocksdb_context.h>
#include <zvec/db/status.h>
| 25 | + |
| 26 | +namespace zvec { |
| 27 | + |
/*! Persistent vector storage backed by RocksDB.
 *
 *  Stores raw float vectors and optional PQ codes in separate column families.
 *  Designed for integration with the Metal GPU backend: vectors are loaded from
 *  RocksDB into contiguous buffers that can be transferred to GPU memory.
 *
 *  Column families:
 *    - "vectors"  : raw float vectors keyed by uint64 ID
 *    - "pq_codes" : PQ-encoded codes keyed by uint64 ID
 *    - "metadata" : vector dimension, stored under the key "dim" (the vector
 *                   count and quantizer parameters are not persisted by this
 *                   class)
 */
| 39 | +class VectorStorage { |
| 40 | + public: |
| 41 | + VectorStorage() = default; |
| 42 | + ~VectorStorage() { close(); } |
| 43 | + |
| 44 | + //! Create a new vector storage at the given path. |
| 45 | + //! @param path Directory for the RocksDB instance. |
| 46 | + //! @param dim Vector dimension. |
| 47 | + Status create(const std::string &path, size_t dim) { |
| 48 | + dim_ = dim; |
| 49 | + Status s = ctx_.create(path, {"default", "vectors", "pq_codes", "metadata"}); |
| 50 | + if (!s.ok()) return s; |
| 51 | + return write_metadata(); |
| 52 | + } |
| 53 | + |
| 54 | + //! Open an existing vector storage. |
| 55 | + //! @param path Directory of the RocksDB instance. |
| 56 | + //! @param read_only Open in read-only mode. |
| 57 | + Status open(const std::string &path, bool read_only = false) { |
| 58 | + Status s = ctx_.open(path, {"default", "vectors", "pq_codes", "metadata"}, |
| 59 | + read_only); |
| 60 | + if (!s.ok()) return s; |
| 61 | + return read_metadata(); |
| 62 | + } |
| 63 | + |
| 64 | + //! Close the storage. |
| 65 | + Status close() { return ctx_.close(); } |
| 66 | + |
| 67 | + //! Retrieve the vector dimension. |
| 68 | + size_t dim() const { return dim_; } |
| 69 | + |
| 70 | + //! Retrieve the number of stored vectors. |
| 71 | + size_t count() const { return count_; } |
| 72 | + |
| 73 | + //! Insert a vector with the given ID. |
| 74 | + //! @param id Unique vector identifier. |
| 75 | + //! @param vec Float vector of size dim(). |
| 76 | + Status put_vector(uint64_t id, const float *vec) { |
| 77 | + auto *cf = ctx_.get_cf("vectors"); |
| 78 | + if (!cf) return Status::InternalError("vectors CF not found"); |
| 79 | + |
| 80 | + rocksdb::Slice key(reinterpret_cast<const char *>(&id), sizeof(id)); |
| 81 | + rocksdb::Slice value(reinterpret_cast<const char *>(vec), |
| 82 | + dim_ * sizeof(float)); |
| 83 | + auto s = ctx_.db_->Put(ctx_.write_opts_, cf, key, value); |
| 84 | + if (!s.ok()) return Status::InternalError(s.ToString()); |
| 85 | + |
| 86 | + ++count_; |
| 87 | + return Status::OK(); |
| 88 | + } |
| 89 | + |
| 90 | + //! Retrieve a vector by ID. |
| 91 | + //! @param id Vector identifier. |
| 92 | + //! @param vec Output buffer of size dim(). |
| 93 | + //! @return OK if found, NotFound otherwise. |
| 94 | + Status get_vector(uint64_t id, float *vec) const { |
| 95 | + auto *cf = ctx_.get_cf("vectors"); |
| 96 | + if (!cf) return Status::InternalError("vectors CF not found"); |
| 97 | + |
| 98 | + rocksdb::Slice key(reinterpret_cast<const char *>(&id), sizeof(id)); |
| 99 | + std::string value; |
| 100 | + auto s = ctx_.db_->Get(ctx_.read_opts_, cf, key, &value); |
| 101 | + if (!s.ok()) return Status::NotFound("vector not found"); |
| 102 | + |
| 103 | + if (value.size() != dim_ * sizeof(float)) { |
| 104 | + return Status::InternalError("corrupted vector data"); |
| 105 | + } |
| 106 | + std::memcpy(vec, value.data(), dim_ * sizeof(float)); |
| 107 | + return Status::OK(); |
| 108 | + } |
| 109 | + |
| 110 | + //! Store PQ codes for a vector. |
| 111 | + //! @param id Vector identifier. |
| 112 | + //! @param codes PQ code array of size m. |
| 113 | + //! @param m Number of sub-quantizers. |
| 114 | + Status put_pq_codes(uint64_t id, const uint8_t *codes, size_t m) { |
| 115 | + auto *cf = ctx_.get_cf("pq_codes"); |
| 116 | + if (!cf) return Status::InternalError( "pq_codes CF not found"); |
| 117 | + |
| 118 | + rocksdb::Slice key(reinterpret_cast<const char *>(&id), sizeof(id)); |
| 119 | + rocksdb::Slice value(reinterpret_cast<const char *>(codes), m); |
| 120 | + auto s = ctx_.db_->Put(ctx_.write_opts_, cf, key, value); |
| 121 | + if (!s.ok()) return Status::InternalError( s.ToString()); |
| 122 | + |
| 123 | + return Status::OK(); |
| 124 | + } |
| 125 | + |
| 126 | + //! Retrieve PQ codes for a vector. |
| 127 | + //! @param id Vector identifier. |
| 128 | + //! @param codes Output buffer of size m. |
| 129 | + //! @param m Number of sub-quantizers. |
| 130 | + Status get_pq_codes(uint64_t id, uint8_t *codes, size_t m) const { |
| 131 | + auto *cf = ctx_.get_cf("pq_codes"); |
| 132 | + if (!cf) return Status::InternalError( "pq_codes CF not found"); |
| 133 | + |
| 134 | + rocksdb::Slice key(reinterpret_cast<const char *>(&id), sizeof(id)); |
| 135 | + std::string value; |
| 136 | + auto s = ctx_.db_->Get(ctx_.read_opts_, cf, key, &value); |
| 137 | + if (!s.ok()) return Status::NotFound( "pq codes not found"); |
| 138 | + |
| 139 | + size_t copy_len = std::min(m, value.size()); |
| 140 | + std::memcpy(codes, value.data(), copy_len); |
| 141 | + return Status::OK(); |
| 142 | + } |
| 143 | + |
| 144 | + //! Batch insert vectors. |
| 145 | + //! @param ids Array of n vector IDs. |
| 146 | + //! @param vecs Contiguous float array (n x dim), row-major. |
| 147 | + //! @param n Number of vectors. |
| 148 | + Status put_vectors_batch(const uint64_t *ids, const float *vecs, size_t n) { |
| 149 | + auto *cf = ctx_.get_cf("vectors"); |
| 150 | + if (!cf) return Status::InternalError( "vectors CF not found"); |
| 151 | + |
| 152 | + rocksdb::WriteBatch batch; |
| 153 | + for (size_t i = 0; i < n; ++i) { |
| 154 | + rocksdb::Slice key(reinterpret_cast<const char *>(&ids[i]), |
| 155 | + sizeof(uint64_t)); |
| 156 | + rocksdb::Slice value(reinterpret_cast<const char *>(vecs + i * dim_), |
| 157 | + dim_ * sizeof(float)); |
| 158 | + batch.Put(cf, key, value); |
| 159 | + } |
| 160 | + |
| 161 | + auto s = ctx_.db_->Write(ctx_.write_opts_, &batch); |
| 162 | + if (!s.ok()) return Status::InternalError( s.ToString()); |
| 163 | + |
| 164 | + count_ += n; |
| 165 | + return Status::OK(); |
| 166 | + } |
| 167 | + |
| 168 | + //! Load all vectors into a contiguous buffer for GPU transfer. |
| 169 | + //! Returns vectors in insertion order. The caller should allocate |
| 170 | + //! the output buffers. |
| 171 | + //! @param out_ids Output vector IDs (resized to count). |
| 172 | + //! @param out_vecs Output float buffer (count x dim), row-major. |
| 173 | + Status load_all(std::vector<uint64_t> &out_ids, |
| 174 | + std::vector<float> &out_vecs) const { |
| 175 | + auto *cf = ctx_.get_cf("vectors"); |
| 176 | + if (!cf) return Status::InternalError( "vectors CF not found"); |
| 177 | + |
| 178 | + out_ids.clear(); |
| 179 | + out_vecs.clear(); |
| 180 | + out_ids.reserve(count_); |
| 181 | + out_vecs.reserve(count_ * dim_); |
| 182 | + |
| 183 | + std::unique_ptr<rocksdb::Iterator> it( |
| 184 | + ctx_.db_->NewIterator(ctx_.read_opts_, cf)); |
| 185 | + for (it->SeekToFirst(); it->Valid(); it->Next()) { |
| 186 | + auto key = it->key(); |
| 187 | + auto value = it->value(); |
| 188 | + |
| 189 | + if (key.size() != sizeof(uint64_t) || |
| 190 | + value.size() != dim_ * sizeof(float)) { |
| 191 | + continue; |
| 192 | + } |
| 193 | + |
| 194 | + uint64_t id; |
| 195 | + std::memcpy(&id, key.data(), sizeof(uint64_t)); |
| 196 | + out_ids.push_back(id); |
| 197 | + |
| 198 | + size_t offset = out_vecs.size(); |
| 199 | + out_vecs.resize(offset + dim_); |
| 200 | + std::memcpy(out_vecs.data() + offset, value.data(), |
| 201 | + dim_ * sizeof(float)); |
| 202 | + } |
| 203 | + |
| 204 | + return Status::OK(); |
| 205 | + } |
| 206 | + |
| 207 | + //! Flush pending writes to disk. |
| 208 | + Status flush() { return ctx_.flush(); } |
| 209 | + |
| 210 | + private: |
| 211 | + VectorStorage(const VectorStorage &) = delete; |
| 212 | + VectorStorage &operator=(const VectorStorage &) = delete; |
| 213 | + |
| 214 | + Status write_metadata() { |
| 215 | + auto *cf = ctx_.get_cf("metadata"); |
| 216 | + if (!cf) return Status::InternalError( "metadata CF not found"); |
| 217 | + |
| 218 | + // Write dimension |
| 219 | + auto s = ctx_.db_->Put(ctx_.write_opts_, cf, "dim", |
| 220 | + rocksdb::Slice(reinterpret_cast<const char *>(&dim_), |
| 221 | + sizeof(dim_))); |
| 222 | + if (!s.ok()) return Status::InternalError( s.ToString()); |
| 223 | + |
| 224 | + return Status::OK(); |
| 225 | + } |
| 226 | + |
| 227 | + Status read_metadata() { |
| 228 | + auto *cf = ctx_.get_cf("metadata"); |
| 229 | + if (!cf) return Status::InternalError( "metadata CF not found"); |
| 230 | + |
| 231 | + std::string value; |
| 232 | + auto s = ctx_.db_->Get(ctx_.read_opts_, cf, "dim", &value); |
| 233 | + if (!s.ok()) return Status::InternalError( "missing dim"); |
| 234 | + if (value.size() != sizeof(size_t)) { |
| 235 | + return Status::InternalError( "corrupted dim"); |
| 236 | + } |
| 237 | + std::memcpy(&dim_, value.data(), sizeof(size_t)); |
| 238 | + |
| 239 | + // Count vectors |
| 240 | + auto *vec_cf = ctx_.get_cf("vectors"); |
| 241 | + if (vec_cf) { |
| 242 | + count_ = 0; |
| 243 | + std::unique_ptr<rocksdb::Iterator> it( |
| 244 | + ctx_.db_->NewIterator(ctx_.read_opts_, vec_cf)); |
| 245 | + for (it->SeekToFirst(); it->Valid(); it->Next()) { |
| 246 | + ++count_; |
| 247 | + } |
| 248 | + } |
| 249 | + |
| 250 | + return Status::OK(); |
| 251 | + } |
| 252 | + |
| 253 | + mutable RocksdbContext ctx_; |
| 254 | + size_t dim_{0}; |
| 255 | + size_t count_{0}; |
| 256 | +}; |
| 257 | + |
| 258 | +} // namespace zvec |
0 commit comments