Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 64 additions & 12 deletions batch_writer.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (

sql "github.com/Shopify/ghostferry/sqlwrapper"

"github.com/go-mysql-org/go-mysql/schema"
"github.com/sirupsen/logrus"
)

Expand Down Expand Up @@ -56,14 +57,65 @@ func (w *BatchWriter) WriteRowBatch(batch *RowBatch) error {
return nil
}

startPaginationKeypos, err := values[0].GetUint64(batch.PaginationKeyIndex())
if err != nil {
return err
}
var startPaginationKeypos, endPaginationKeypos PaginationKey
var err error

paginationColumn := batch.TableSchema().GetPaginationColumn()

endPaginationKeypos, err := values[len(values)-1].GetUint64(batch.PaginationKeyIndex())
if err != nil {
return err
switch paginationColumn.Type {
case schema.TYPE_NUMBER, schema.TYPE_MEDIUM_INT:
var startValue, endValue uint64
startValue, err = values[0].GetUint64(batch.PaginationKeyIndex())
if err != nil {
return err
}
endValue, err = values[len(values)-1].GetUint64(batch.PaginationKeyIndex())
if err != nil {
return err
}
startPaginationKeypos = NewUint64Key(startValue)
endPaginationKeypos = NewUint64Key(endValue)

case schema.TYPE_BINARY, schema.TYPE_STRING:
startValueInterface := values[0][batch.PaginationKeyIndex()]
endValueInterface := values[len(values)-1][batch.PaginationKeyIndex()]

getBytes := func(val interface{}) ([]byte, error) {
switch v := val.(type) {
case []byte:
return v, nil
case string:
return []byte(v), nil
default:
return nil, fmt.Errorf("expected binary/string pagination key, got %T", val)
}
}

startValue, err := getBytes(startValueInterface)
if err != nil {
return err
}

endValue, err := getBytes(endValueInterface)
if err != nil {
return err
}

startPaginationKeypos = NewBinaryKey(startValue)
endPaginationKeypos = NewBinaryKey(endValue)

default:
var startValue, endValue uint64
startValue, err = values[0].GetUint64(batch.PaginationKeyIndex())
if err != nil {
return err
}
endValue, err = values[len(values)-1].GetUint64(batch.PaginationKeyIndex())
if err != nil {
return err
}
startPaginationKeypos = NewUint64Key(startValue)
endPaginationKeypos = NewUint64Key(endValue)
}

db := batch.TableSchema().Schema
Expand All @@ -78,12 +130,12 @@ func (w *BatchWriter) WriteRowBatch(batch *RowBatch) error {

query, args, err := batch.AsSQLQuery(db, table)
if err != nil {
return fmt.Errorf("during generating sql query at paginationKey %v -> %v: %v", startPaginationKeypos, endPaginationKeypos, err)
return fmt.Errorf("during generating sql query at paginationKey %s -> %s: %v", startPaginationKeypos.String(), endPaginationKeypos.String(), err)
}

stmt, err := w.stmtCache.StmtFor(w.DB, query)
if err != nil {
return fmt.Errorf("during prepare query near paginationKey %v -> %v (%s): %v", startPaginationKeypos, endPaginationKeypos, query, err)
return fmt.Errorf("during prepare query near paginationKey %s -> %s (%s): %v", startPaginationKeypos.String(), endPaginationKeypos.String(), query, err)
}

tx, err := w.DB.Begin()
Expand All @@ -94,14 +146,14 @@ func (w *BatchWriter) WriteRowBatch(batch *RowBatch) error {
_, err = tx.Stmt(stmt).Exec(args...)
if err != nil {
tx.Rollback()
return fmt.Errorf("during exec query near paginationKey %v -> %v (%s): %v", startPaginationKeypos, endPaginationKeypos, query, err)
return fmt.Errorf("during exec query near paginationKey %s -> %s (%s): %v", startPaginationKeypos.String(), endPaginationKeypos.String(), query, err)
}

if w.InlineVerifier != nil {
mismatches, err := w.InlineVerifier.CheckFingerprintInline(tx, db, table, batch, w.EnforceInlineVerification)
if err != nil {
tx.Rollback()
return fmt.Errorf("during fingerprint checking for paginationKey %v -> %v (%s): %v", startPaginationKeypos, endPaginationKeypos, query, err)
return fmt.Errorf("during fingerprint checking for paginationKey %s -> %s (%s): %v", startPaginationKeypos.String(), endPaginationKeypos.String(), query, err)
}

if w.EnforceInlineVerification {
Expand All @@ -119,7 +171,7 @@ func (w *BatchWriter) WriteRowBatch(batch *RowBatch) error {
err = tx.Commit()
if err != nil {
tx.Rollback()
return fmt.Errorf("during commit near paginationKey %v -> %v (%s): %v", startPaginationKeypos, endPaginationKeypos, query, err)
return fmt.Errorf("during commit near paginationKey %s -> %s (%s): %v", startPaginationKeypos.String(), endPaginationKeypos.String(), query, err)
}

// Note that the state tracker expects us to track based on the original
Expand Down
157 changes: 130 additions & 27 deletions compression_verifier.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ func (e UnsupportedCompressionError) Error() string {
type CompressionVerifier struct {
// logger is the logrus entry used for this verifier's log output;
// NewCompressionVerifier tags it with "compression_verifier".
logger *logrus.Entry

// TableSchemaCache is consulted by GetCompressedHashes to look up a table by
// schema name and table name so its pagination columns can be determined.
TableSchemaCache TableSchemaCache
// supportedAlgorithms is a set keyed by compression algorithm name.
// NewCompressionVerifier seeds it with CompressionSnappy.
// NOTE(review): presumably used to validate the configured algorithms —
// confirm against verifyConfiguredCompression.
supportedAlgorithms map[string]struct{}
// tableColumnCompressions is indexed by table name and then by column name,
// yielding the algorithm passed to Decompress for that column.
tableColumnCompressions TableColumnCompressionConfig
}
Expand All @@ -59,32 +60,66 @@ type CompressionVerifier struct {
// The GetCompressedHashes method checks if the existing table contains compressed data
// and will apply the decompression algorithm to the applicable columns if necessary.
// After the columns are decompressed, the hashes of the data are used to verify equality
func (c *CompressionVerifier) GetCompressedHashes(db *sql.DB, schema, table, paginationKeyColumn string, columns []schema.TableColumn, paginationKeys []uint64) (map[uint64][]byte, error) {
func (c *CompressionVerifier) GetCompressedHashes(db *sql.DB, schemaName, tableName, paginationKeyColumn string, columns []schema.TableColumn, paginationKeys []interface{}) (map[string][]byte, error) {
c.logger.WithFields(logrus.Fields{
"tag": "compression_verifier",
"table": table,
"table": tableName,
}).Info("decompressing table data before verification")

tableCompression := c.tableColumnCompressions[table]
tableCompression := c.tableColumnCompressions[tableName]

table := c.TableSchemaCache.Get(schemaName, tableName)
if table == nil {
return nil, fmt.Errorf("table %s.%s not found in schema cache", schemaName, tableName)
}
paginationColumns := table.GetPaginationColumns()

// Extract the raw rows using SQL to be decompressed
rows, err := getRows(db, schema, table, paginationKeyColumn, columns, paginationKeys)
rows, err := getRows(db, schemaName, tableName, paginationColumns, columns, paginationKeys)
if err != nil {
return nil, err
}
defer rows.Close()

// Decompress applicable columns and hash the resulting column values for comparison
resultSet := make(map[uint64][]byte)
resultSet := make(map[string][]byte)
numPaginationCols := len(paginationColumns)

for rows.Next() {
rowData, err := ScanByteRow(rows, len(columns)+1)
// Scan: pagination_col1, pagination_col2, ..., data_cols...
rowData, err := ScanByteRow(rows, len(columns)+numPaginationCols)
if err != nil {
return nil, err
}

paginationKey, err := strconv.ParseUint(string(rowData[0]), 10, 64)
if err != nil {
return nil, err
// Build pagination key from columns (works for both single and composite keys)
keys := make([]PaginationKey, len(paginationColumns))
for i, paginationColumn := range paginationColumns {
switch paginationColumn.Type {
case schema.TYPE_NUMBER, schema.TYPE_MEDIUM_INT:
paginationKeyUint, err := strconv.ParseUint(string(rowData[i]), 10, 64)
if err != nil {
return nil, err
}
keys[i] = NewUint64Key(paginationKeyUint)

case schema.TYPE_BINARY, schema.TYPE_STRING:
keys[i] = NewBinaryKey(rowData[i])

default:
paginationKeyUint, err := strconv.ParseUint(string(rowData[i]), 10, 64)
if err != nil {
return nil, err
}
keys[i] = NewUint64Key(paginationKeyUint)
}
}

// For single column, use the key directly; for composite, wrap in CompositeKey
var paginationKeyStr string
if len(keys) == 1 {
paginationKeyStr = keys[0].String()
} else {
paginationKeyStr = CompositeKey(keys).String()
}

// Decompress the applicable columns and then hash them together
Expand All @@ -94,14 +129,14 @@ func (c *CompressionVerifier) GetCompressedHashes(db *sql.DB, schema, table, pag
decompressedRowData := [][]byte{}
for idx, column := range columns {
if algorithm, ok := tableCompression[column.Name]; ok {
// rowData contains the result of "SELECT paginationKeyColumn, * FROM ...", so idx+1 to get each column
decompressedColData, err := c.Decompress(table, column.Name, algorithm, rowData[idx+1])
// rowData contains the result of "SELECT paginationKeyCols..., * FROM ...", so idx+numPaginationCols to get each data column
decompressedColData, err := c.Decompress(tableName, column.Name, algorithm, rowData[idx+numPaginationCols])
if err != nil {
return nil, err
}
decompressedRowData = append(decompressedRowData, decompressedColData)
} else {
decompressedRowData = append(decompressedRowData, rowData[idx+1])
decompressedRowData = append(decompressedRowData, rowData[idx+numPaginationCols])
}
}

Expand All @@ -111,20 +146,20 @@ func (c *CompressionVerifier) GetCompressedHashes(db *sql.DB, schema, table, pag
return nil, err
}

resultSet[paginationKey] = decompressedRowHash
resultSet[paginationKeyStr] = decompressedRowHash
}

metrics.Gauge(
"compression_verifier_decompress_rows",
float64(len(resultSet)),
[]MetricTag{{"table", table}},
[]MetricTag{{"table", tableName}},
1.0,
)

logrus.WithFields(logrus.Fields{
"tag": "compression_verifier",
"rows": len(resultSet),
"table": table,
"table": tableName,
}).Debug("decompressed rows will be compared")

return resultSet, nil
Expand Down Expand Up @@ -192,12 +227,13 @@ func (c *CompressionVerifier) verifyConfiguredCompression(tableColumnCompression

// NewCompressionVerifier first checks the map for supported compression algorithms before
// initializing and returning the initialized instance.
func NewCompressionVerifier(tableColumnCompressions TableColumnCompressionConfig) (*CompressionVerifier, error) {
func NewCompressionVerifier(tableColumnCompressions TableColumnCompressionConfig, tableSchemaCache TableSchemaCache) (*CompressionVerifier, error) {
supportedAlgorithms := make(map[string]struct{})
supportedAlgorithms[CompressionSnappy] = struct{}{}

compressionVerifier := &CompressionVerifier{
logger: logrus.WithField("tag", "compression_verifier"),
TableSchemaCache: tableSchemaCache,
supportedAlgorithms: supportedAlgorithms,
tableColumnCompressions: tableColumnCompressions,
}
Expand All @@ -209,14 +245,75 @@ func NewCompressionVerifier(tableColumnCompressions TableColumnCompressionConfig
return compressionVerifier, nil
}

func getRows(db *sql.DB, schema, table, paginationKeyColumn string, columns []schema.TableColumn, paginationKeys []uint64) (*sqlorig.Rows, error) {
quotedPaginationKey := QuoteField(paginationKeyColumn)
sql, args, err := rowSelector(columns, paginationKeyColumn).
From(QuotedTableNameFromString(schema, table)).
Where(sq.Eq{quotedPaginationKey: paginationKeys}).
OrderBy(quotedPaginationKey).
ToSql()

func getRows(db *sql.DB, schemaName, table string, paginationKeyColumns []*schema.TableColumn, columns []schema.TableColumn, paginationKeys []interface{}) (*sqlorig.Rows, error) {
builder := rowSelector(columns, paginationKeyColumns).
From(QuotedTableNameFromString(schemaName, table))

if len(paginationKeyColumns) == 1 {
// Single column WHERE clause
quotedPaginationKey := QuoteField(paginationKeyColumns[0].Name)
builder = builder.Where(sq.Eq{quotedPaginationKey: paginationKeys})
builder = builder.OrderBy(quotedPaginationKey)
} else {
// Composite key WHERE clause: (col1, col2) IN ((?, ?), (?, ?), ...)
quotedPKCols := make([]string, len(paginationKeyColumns))
for i, col := range paginationKeyColumns {
quotedPKCols[i] = QuoteField(col.Name)
}
tuple := fmt.Sprintf("(%s)", strings.Join(quotedPKCols, ", "))

// Build placeholder tuples for each pagination key string
placeholderTuples := make([]string, len(paginationKeys))
args := make([]interface{}, 0, len(paginationKeys)*len(paginationKeyColumns))

for i, pkInterface := range paginationKeys {
pkStr, ok := pkInterface.(string)
if !ok {
return nil, fmt.Errorf("expected string pagination key for composite key, got %T", pkInterface)
}

// Parse the composite key string (comma-separated)
parts := strings.Split(pkStr, ",")
if len(parts) != len(paginationKeyColumns) {
return nil, fmt.Errorf("pagination key has %d parts but expected %d", len(parts), len(paginationKeyColumns))
}

placeholders := make([]string, len(parts))
for j, part := range parts {
placeholders[j] = "?"
// Convert string representation back to appropriate type
col := paginationKeyColumns[j]
switch col.Type {
case schema.TYPE_NUMBER, schema.TYPE_MEDIUM_INT:
val, err := strconv.ParseUint(part, 10, 64)
if err != nil {
return nil, fmt.Errorf("failed to parse pagination key part %q as uint64: %w", part, err)
}
args = append(args, val)
case schema.TYPE_BINARY, schema.TYPE_STRING:
// For binary keys, the string is hex-encoded
decoded, err := hex.DecodeString(part)
if err != nil {
return nil, fmt.Errorf("failed to decode pagination key part %q: %w", part, err)
}
args = append(args, decoded)
default:
val, err := strconv.ParseUint(part, 10, 64)
if err != nil {
return nil, fmt.Errorf("failed to parse pagination key part %q: %w", part, err)
}
args = append(args, val)
}
}
placeholderTuples[i] = fmt.Sprintf("(%s)", strings.Join(placeholders, ", "))
}

whereClause := fmt.Sprintf("%s IN (%s)", tuple, strings.Join(placeholderTuples, ", "))
builder = builder.Where(whereClause, args...)
builder = builder.OrderBy(strings.Join(quotedPKCols, ", "))
}

sql, args, err := builder.ToSql()
if err != nil {
return nil, err
}
Expand All @@ -238,11 +335,17 @@ func getRows(db *sql.DB, schema, table, paginationKeyColumn string, columns []sc
return rows, nil
}

func rowSelector(columns []schema.TableColumn, paginationKeyColumn string) sq.SelectBuilder {
func rowSelector(columns []schema.TableColumn, paginationKeyColumns []*schema.TableColumn) sq.SelectBuilder {
// Select all pagination key columns first
selectParts := make([]string, len(paginationKeyColumns))
for i, col := range paginationKeyColumns {
selectParts[i] = QuoteField(col.Name)
}

columnStrs := make([]string, len(columns))
for idx, column := range columns {
columnStrs[idx] = column.Name
}

return sq.Select(fmt.Sprintf("%s, %s", QuoteField(paginationKeyColumn), strings.Join(columnStrs, ",")))
return sq.Select(fmt.Sprintf("%s, %s", strings.Join(selectParts, ", "), strings.Join(columnStrs, ",")))
}
Loading