Commit 7a37eb4

Pagination beyond int64
1 parent 4bc2247 commit 7a37eb4

28 files changed

Lines changed: 1920 additions & 280 deletions

batch_writer.go

Lines changed: 65 additions & 13 deletions
@@ -7,6 +7,7 @@ import (

    sql "github.com/Shopify/ghostferry/sqlwrapper"

+   "github.com/go-mysql-org/go-mysql/schema"
    "github.com/sirupsen/logrus"
)

@@ -56,14 +57,65 @@ func (w *BatchWriter) WriteRowBatch(batch *RowBatch) error {
        return nil
    }

-   startPaginationKeypos, err := values[0].GetUint64(batch.PaginationKeyIndex())
-   if err != nil {
-       return err
-   }
+   var startPaginationKeypos, endPaginationKeypos PaginationKey
+   var err error
+
+   paginationColumn := batch.TableSchema().GetPaginationColumn()

-   endPaginationKeypos, err := values[len(values)-1].GetUint64(batch.PaginationKeyIndex())
-   if err != nil {
-       return err
+   switch paginationColumn.Type {
+   case schema.TYPE_NUMBER, schema.TYPE_MEDIUM_INT:
+       var startValue, endValue uint64
+       startValue, err = values[0].GetUint64(batch.PaginationKeyIndex())
+       if err != nil {
+           return err
+       }
+       endValue, err = values[len(values)-1].GetUint64(batch.PaginationKeyIndex())
+       if err != nil {
+           return err
+       }
+       startPaginationKeypos = NewUint64Key(startValue)
+       endPaginationKeypos = NewUint64Key(endValue)
+
+   case schema.TYPE_BINARY, schema.TYPE_STRING:
+       startValueInterface := values[0][batch.PaginationKeyIndex()]
+       endValueInterface := values[len(values)-1][batch.PaginationKeyIndex()]
+
+       getBytes := func(val interface{}) ([]byte, error) {
+           switch v := val.(type) {
+           case []byte:
+               return v, nil
+           case string:
+               return []byte(v), nil
+           default:
+               return nil, fmt.Errorf("expected binary/string pagination key, got %T", val)
+           }
+       }
+
+       startValue, err := getBytes(startValueInterface)
+       if err != nil {
+           return err
+       }
+
+       endValue, err := getBytes(endValueInterface)
+       if err != nil {
+           return err
+       }
+
+       startPaginationKeypos = NewBinaryKey(startValue)
+       endPaginationKeypos = NewBinaryKey(endValue)
+
+   default:
+       var startValue, endValue uint64
+       startValue, err = values[0].GetUint64(batch.PaginationKeyIndex())
+       if err != nil {
+           return err
+       }
+       endValue, err = values[len(values)-1].GetUint64(batch.PaginationKeyIndex())
+       if err != nil {
+           return err
+       }
+       startPaginationKeypos = NewUint64Key(startValue)
+       endPaginationKeypos = NewUint64Key(endValue)
    }

    db := batch.TableSchema().Schema
@@ -78,12 +130,12 @@ func (w *BatchWriter) WriteRowBatch(batch *RowBatch) error {

    query, args, err := batch.AsSQLQuery(db, table)
    if err != nil {
-       return fmt.Errorf("during generating sql query at paginationKey %v -> %v: %v", startPaginationKeypos, endPaginationKeypos, err)
+       return fmt.Errorf("during generating sql query at paginationKey %s -> %s: %v", startPaginationKeypos.String(), endPaginationKeypos.String(), err)
    }
-
+
    stmt, err := w.stmtCache.StmtFor(w.DB, query)
    if err != nil {
-       return fmt.Errorf("during prepare query near paginationKey %v -> %v (%s): %v", startPaginationKeypos, endPaginationKeypos, query, err)
+       return fmt.Errorf("during prepare query near paginationKey %s -> %s (%s): %v", startPaginationKeypos.String(), endPaginationKeypos.String(), query, err)
    }

    tx, err := w.DB.Begin()
@@ -94,14 +146,14 @@ func (w *BatchWriter) WriteRowBatch(batch *RowBatch) error {
    _, err = tx.Stmt(stmt).Exec(args...)
    if err != nil {
        tx.Rollback()
-       return fmt.Errorf("during exec query near paginationKey %v -> %v (%s): %v", startPaginationKeypos, endPaginationKeypos, query, err)
+       return fmt.Errorf("during exec query near paginationKey %s -> %s (%s): %v", startPaginationKeypos.String(), endPaginationKeypos.String(), query, err)
    }

    if w.InlineVerifier != nil {
        mismatches, err := w.InlineVerifier.CheckFingerprintInline(tx, db, table, batch, w.EnforceInlineVerification)
        if err != nil {
            tx.Rollback()
-           return fmt.Errorf("during fingerprint checking for paginationKey %v -> %v (%s): %v", startPaginationKeypos, endPaginationKeypos, query, err)
+           return fmt.Errorf("during fingerprint checking for paginationKey %s -> %s (%s): %v", startPaginationKeypos.String(), endPaginationKeypos.String(), query, err)
        }

        if w.EnforceInlineVerification {
@@ -119,7 +171,7 @@ func (w *BatchWriter) WriteRowBatch(batch *RowBatch) error {
    err = tx.Commit()
    if err != nil {
        tx.Rollback()
-       return fmt.Errorf("during commit near paginationKey %v -> %v (%s): %v", startPaginationKeypos, endPaginationKeypos, query, err)
+       return fmt.Errorf("during commit near paginationKey %s -> %s (%s): %v", startPaginationKeypos.String(), endPaginationKeypos.String(), query, err)
    }

    // Note that the state tracker expects us the track based on the original
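
The batch writer now works through a PaginationKey abstraction whose definition lives in another file of this commit and is not shown here. Below is a rough sketch of what such an abstraction could look like: the names PaginationKey, NewUint64Key, NewBinaryKey, and String() come from the diff above, while the Compare method and the hex rendering are assumptions made purely for illustration.

// Sketch only — not the commit's actual definition. A PaginationKey that can
// represent either a uint64 or a binary value, printable and comparable.
package main

import (
    "bytes"
    "encoding/hex"
    "fmt"
)

type PaginationKey interface {
    // String renders the key for error messages, logs, and map keys.
    String() string
    // Compare returns -1, 0, or 1 (hypothetical; the real interface may differ).
    Compare(other PaginationKey) int
}

type Uint64Key uint64

func NewUint64Key(v uint64) PaginationKey { return Uint64Key(v) }

func (k Uint64Key) String() string { return fmt.Sprintf("%d", uint64(k)) }

func (k Uint64Key) Compare(other PaginationKey) int {
    o := other.(Uint64Key)
    switch {
    case k < o:
        return -1
    case k > o:
        return 1
    }
    return 0
}

type BinaryKey []byte

func NewBinaryKey(v []byte) PaginationKey { return BinaryKey(v) }

func (k BinaryKey) String() string { return "0x" + hex.EncodeToString(k) }

func (k BinaryKey) Compare(other PaginationKey) int {
    return bytes.Compare(k, other.(BinaryKey))
}

func main() {
    var start, end PaginationKey = NewUint64Key(1), NewUint64Key(500)
    fmt.Println(start.String(), end.String(), start.Compare(end)) // 1 500 -1
}

Hiding both key kinds behind one interface is what lets WriteRowBatch format its error messages uniformly with String() regardless of whether the pagination column is numeric or binary.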

compression_verifier.go

Lines changed: 37 additions & 15 deletions
@@ -49,6 +49,7 @@ func (e UnsupportedCompressionError) Error() string {
type CompressionVerifier struct {
    logger *logrus.Entry

+   TableSchemaCache        TableSchemaCache
    supportedAlgorithms     map[string]struct{}
    tableColumnCompressions TableColumnCompressionConfig
}
@@ -59,32 +60,52 @@ type CompressionVerifier struct {
// The GetCompressedHashes method checks if the existing table contains compressed data
// and will apply the decompression algorithm to the applicable columns if necessary.
// After the columns are decompressed, the hashes of the data are used to verify equality
-func (c *CompressionVerifier) GetCompressedHashes(db *sql.DB, schema, table, paginationKeyColumn string, columns []schema.TableColumn, paginationKeys []uint64) (map[uint64][]byte, error) {
+func (c *CompressionVerifier) GetCompressedHashes(db *sql.DB, schemaName, tableName, paginationKeyColumn string, columns []schema.TableColumn, paginationKeys []interface{}) (map[string][]byte, error) {
    c.logger.WithFields(logrus.Fields{
        "tag": "compression_verifier",
-       "table": table,
+       "table": tableName,
    }).Info("decompressing table data before verification")

-   tableCompression := c.tableColumnCompressions[table]
+   tableCompression := c.tableColumnCompressions[tableName]

    // Extract the raw rows using SQL to be decompressed
-   rows, err := getRows(db, schema, table, paginationKeyColumn, columns, paginationKeys)
+   rows, err := getRows(db, schemaName, tableName, paginationKeyColumn, columns, paginationKeys)
    if err != nil {
        return nil, err
    }
    defer rows.Close()

-   // Decompress applicable columns and hash the resulting column values for comparison
-   resultSet := make(map[uint64][]byte)
+   table := c.TableSchemaCache.Get(schemaName, tableName)
+   if table == nil {
+       return nil, fmt.Errorf("table %s.%s not found in schema cache", schemaName, tableName)
+   }
+   paginationColumn := table.GetPaginationColumn()
+   resultSet := make(map[string][]byte)
+
    for rows.Next() {
        rowData, err := ScanByteRow(rows, len(columns)+1)
        if err != nil {
            return nil, err
        }

-       paginationKey, err := strconv.ParseUint(string(rowData[0]), 10, 64)
-       if err != nil {
-           return nil, err
+       var paginationKeyStr string
+       switch paginationColumn.Type {
+       case schema.TYPE_NUMBER, schema.TYPE_MEDIUM_INT:
+           paginationKeyUint, err := strconv.ParseUint(string(rowData[0]), 10, 64)
+           if err != nil {
+               return nil, err
+           }
+           paginationKeyStr = NewUint64Key(paginationKeyUint).String()
+
+       case schema.TYPE_BINARY, schema.TYPE_STRING:
+           paginationKeyStr = NewBinaryKey(rowData[0]).String()
+
+       default:
+           paginationKeyUint, err := strconv.ParseUint(string(rowData[0]), 10, 64)
+           if err != nil {
+               return nil, err
+           }
+           paginationKeyStr = NewUint64Key(paginationKeyUint).String()
        }

        // Decompress the applicable columns and then hash them together
@@ -95,7 +116,7 @@ func (c *CompressionVerifier) GetCompressedHashes(db *sql.DB, schema, table, pag
        for idx, column := range columns {
            if algorithm, ok := tableCompression[column.Name]; ok {
                // rowData contains the result of "SELECT paginationKeyColumn, * FROM ...", so idx+1 to get each column
-               decompressedColData, err := c.Decompress(table, column.Name, algorithm, rowData[idx+1])
+               decompressedColData, err := c.Decompress(tableName, column.Name, algorithm, rowData[idx+1])
                if err != nil {
                    return nil, err
                }
@@ -111,20 +132,20 @@ func (c *CompressionVerifier) GetCompressedHashes(db *sql.DB, schema, table, pag
            return nil, err
        }

-       resultSet[paginationKey] = decompressedRowHash
+       resultSet[paginationKeyStr] = decompressedRowHash
    }

    metrics.Gauge(
        "compression_verifier_decompress_rows",
        float64(len(resultSet)),
-       []MetricTag{{"table", table}},
+       []MetricTag{{"table", tableName}},
        1.0,
    )

    logrus.WithFields(logrus.Fields{
        "tag": "compression_verifier",
        "rows": len(resultSet),
-       "table": table,
+       "table": tableName,
    }).Debug("decompressed rows will be compared")

    return resultSet, nil
@@ -192,12 +213,13 @@ func (c *CompressionVerifier) verifyConfiguredCompression(tableColumnCompression

// NewCompressionVerifier first checks the map for supported compression algorithms before
// initializing and returning the initialized instance.
-func NewCompressionVerifier(tableColumnCompressions TableColumnCompressionConfig) (*CompressionVerifier, error) {
+func NewCompressionVerifier(tableColumnCompressions TableColumnCompressionConfig, tableSchemaCache TableSchemaCache) (*CompressionVerifier, error) {
    supportedAlgorithms := make(map[string]struct{})
    supportedAlgorithms[CompressionSnappy] = struct{}{}

    compressionVerifier := &CompressionVerifier{
        logger: logrus.WithField("tag", "compression_verifier"),
+       TableSchemaCache: tableSchemaCache,
        supportedAlgorithms: supportedAlgorithms,
        tableColumnCompressions: tableColumnCompressions,
    }
@@ -209,7 +231,7 @@ func NewCompressionVerifier(tableColumnCompressions TableColumnCompressionConfig
    return compressionVerifier, nil
}

-func getRows(db *sql.DB, schema, table, paginationKeyColumn string, columns []schema.TableColumn, paginationKeys []uint64) (*sqlorig.Rows, error) {
+func getRows(db *sql.DB, schema, table, paginationKeyColumn string, columns []schema.TableColumn, paginationKeys []interface{}) (*sqlorig.Rows, error) {
    quotedPaginationKey := QuoteField(paginationKeyColumn)
    sql, args, err := rowSelector(columns, paginationKeyColumn).
        From(QuotedTableNameFromString(schema, table)).
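
One reason the result set changes from map[uint64][]byte to map[string][]byte: Go slices are not comparable, so []byte cannot be used as a map key, and binary pagination keys therefore have to be keyed by their String() form. A minimal sketch of that constraint and the workaround — the hex rendering is an assumption for illustration; only the string-keyed map is visible in the diff:

// Why result sets are keyed by PaginationKey.String(): []byte is not a valid
// Go map key type, so binary pagination keys are rendered to strings first.
package main

import (
    "encoding/hex"
    "fmt"
)

func main() {
    // resultSet := make(map[[]byte][]byte) // compile error: invalid map key type
    resultSet := make(map[string][]byte)

    binaryKey := []byte{0xde, 0xad, 0xbe, 0xef}
    keyStr := "0x" + hex.EncodeToString(binaryKey) // assumed rendering; the commit's String() may differ
    resultSet[keyStr] = []byte("row fingerprint")

    fmt.Printf("%s\n", resultSet[keyStr]) // looked up by the stable string form
}

Using one string-keyed map for both numeric and binary keys also lets the verifier compare source and target hash sets without branching on the pagination column type again.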

config.go

Lines changed: 13 additions & 3 deletions
@@ -376,12 +376,17 @@ func (c ForceIndexConfig) IndexFor(schemaName, tableName string) string {
// CascadingPaginationColumnConfig to configure pagination columns to be
// used. The term `Cascading` to denote that greater specificity takes
// precedence.
+//
+// IMPORTANT: All configured pagination columns must contain unique values.
+// When specifying a FallbackColumn for tables with composite primary keys,
+// ensure the column has a unique constraint to prevent data loss during migration.
type CascadingPaginationColumnConfig struct {
    // PerTable has greatest specificity and takes precedence over the other options
    PerTable map[string]map[string]string // SchemaName => TableName => ColumnName

    // FallbackColumn is a global default to fallback to and is less specific than the
-   // default, which is the Primary Key
+   // default, which is the Primary Key.
+   // This column MUST have unique values (ideally a unique constraint) for data integrity.
    FallbackColumn string
}

@@ -727,10 +732,15 @@ type Config struct {
    //
    ForceIndexForVerification ForceIndexConfig

-   // Ghostferry requires a single numeric column to paginate over tables. Inferring that column is done in the following exact order:
+   // Ghostferry requires a single numeric or binary column to paginate over tables. Inferring that column is done in the following exact order:
    // 1. Use the PerTable pagination column, if configured for a table. Fail if we cannot find this column in the table.
-   // 2. Use the table's primary key column as the pagination column. Fail if the primary key is not numeric or is a composite key without a FallbackColumn specified.
+   // 2. Use the table's primary key column as the pagination column. Fail if the primary key is not numeric/binary or is a composite key without a FallbackColumn specified.
    // 3. Use the FallbackColumn pagination column, if configured. Fail if we cannot find this column in the table.
+   //
+   // IMPORTANT: The pagination column MUST contain unique values for data integrity.
+   // When using a FallbackColumn (typically "id") for tables with composite primary keys, this column must have a unique constraint.
+   // The pagination algorithm uses WHERE pagination_key > last_key ORDER BY pagination_key LIMIT batch_size.
+   // If duplicate values exist, rows may be skipped during iteration, resulting in data loss during the migration.
    CascadingPaginationColumnConfig *CascadingPaginationColumnConfig

    // SkipTargetVerification is used to enable or disable target verification during moves.
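
The new config comments explain the data-loss hazard in prose; a small self-contained Go simulation of the described cursor loop (WHERE pagination_key > last_key ORDER BY pagination_key LIMIT batch_size) on hypothetical in-memory data makes the skip concrete:

// Simulates the batch pagination loop described in the config comments above
// to show how duplicate pagination key values cause rows to be skipped.
package main

import "fmt"

func main() {
    // Hypothetical table: pagination column values with duplicates,
    // already sorted as ORDER BY pagination_key would return them.
    keys := []uint64{1, 2, 2, 2, 3}
    batchSize := 2

    var lastKey uint64 // pagination cursor, starts below all keys
    copied := 0
    for {
        // Equivalent of: WHERE pagination_key > lastKey ORDER BY pagination_key LIMIT batchSize
        batch := []uint64{}
        for _, k := range keys {
            if k > lastKey && len(batch) < batchSize {
                batch = append(batch, k)
            }
        }
        if len(batch) == 0 {
            break
        }
        copied += len(batch)
        lastKey = batch[len(batch)-1]
        fmt.Println("batch:", batch, "new cursor:", lastKey)
    }
    // The cursor jumps past the duplicate 2s after reading only one of them.
    fmt.Printf("copied %d of %d rows\n", copied, len(keys))
}

With values 1, 2, 2, 2, 3 and a batch size of 2, only 3 of the 5 rows are copied: the first batch reads keys 1 and 2 and advances the cursor past 2, so the remaining duplicates are never selected. That is exactly why the pagination column must be unique.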
