Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 49 additions & 39 deletions pkg/accounting/accounting.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import (
"sync"
"time"

"github.com/ethersphere/bee/v2/pkg/bigint"
"github.com/ethersphere/bee/v2/pkg/log"
"github.com/ethersphere/bee/v2/pkg/p2p"
"github.com/ethersphere/bee/v2/pkg/pricing"
Expand Down Expand Up @@ -353,7 +354,7 @@ func (c *creditAction) Apply() error {

loggerV2.Debug("credit action apply", "crediting_peer_address", c.peer, "price", c.price, "new_balance", nextBalance)

err = c.accounting.store.Put(peerBalanceKey(c.peer), nextBalance)
err = c.accounting.store.Put(peerBalanceKey(c.peer), &bigint.BigInt{Int: nextBalance})
if err != nil {
return fmt.Errorf("failed to persist balance: %w", err)
}
Expand Down Expand Up @@ -406,7 +407,7 @@ func (c *creditAction) Apply() error {
loggerV2.Debug("credit action apply; decreasing originated balance", "crediting_peer_address", c.peer, "current_balance", nextOriginBalance)
}

err = c.accounting.store.Put(originatedBalanceKey(c.peer), nextOriginBalance)
err = c.accounting.store.Put(originatedBalanceKey(c.peer), &bigint.BigInt{Int: nextOriginBalance})
if err != nil {
return fmt.Errorf("failed to persist originated balance: %w", err)
}
Expand Down Expand Up @@ -519,7 +520,8 @@ func (a *Accounting) settle(peer swarm.Address, balance *accountingPeer) error {

// Balance returns the current balance for the given peer.
func (a *Accounting) Balance(peer swarm.Address) (balance *big.Int, err error) {
err = a.store.Get(peerBalanceKey(peer), &balance)
var w bigint.BigInt
err = a.store.Get(peerBalanceKey(peer), &w)

if err != nil {
if errors.Is(err, storage.ErrNotFound) {
Expand All @@ -528,12 +530,13 @@ func (a *Accounting) Balance(peer swarm.Address) (balance *big.Int, err error) {
return nil, err
}

return balance, nil
return w.Int, nil
}

// OriginatedBalance returns the current balance for the given peer.
func (a *Accounting) OriginatedBalance(peer swarm.Address) (balance *big.Int, err error) {
err = a.store.Get(originatedBalanceKey(peer), &balance)
var w bigint.BigInt
err = a.store.Get(originatedBalanceKey(peer), &w)

if err != nil {
if errors.Is(err, storage.ErrNotFound) {
Expand All @@ -542,12 +545,13 @@ func (a *Accounting) OriginatedBalance(peer swarm.Address) (balance *big.Int, er
return nil, err
}

return balance, nil
return w.Int, nil
}

// SurplusBalance returns the current balance for the given peer.
func (a *Accounting) SurplusBalance(peer swarm.Address) (balance *big.Int, err error) {
err = a.store.Get(peerSurplusBalanceKey(peer), &balance)
var w bigint.BigInt
err = a.store.Get(peerSurplusBalanceKey(peer), &w)

if err != nil {
if errors.Is(err, storage.ErrNotFound) {
Expand All @@ -556,11 +560,11 @@ func (a *Accounting) SurplusBalance(peer swarm.Address) (balance *big.Int, err e
return nil, err
}

if balance.Cmp(big.NewInt(0)) < 0 {
if w.Cmp(big.NewInt(0)) < 0 {
return nil, ErrInvalidValue
}

return balance, nil
return w.Int, nil
}

// CompensatedBalance returns balance decreased by surplus balance
Expand Down Expand Up @@ -682,13 +686,13 @@ func (a *Accounting) Balances() (map[string]*big.Int, error) {
}

if _, ok := s[addr.String()]; !ok {
var storevalue *big.Int
err = a.store.Get(peerBalanceKey(addr), &storevalue)
var w bigint.BigInt
err = a.store.Get(peerBalanceKey(addr), &w)
if err != nil {
return false, fmt.Errorf("get peer %s balance: %w", addr.String(), err)
}

s[addr.String()] = storevalue
s[addr.String()] = w.Int
}

return false, nil
Expand Down Expand Up @@ -866,14 +870,15 @@ func (a *Accounting) PeerDebt(peer swarm.Address) (*big.Int, error) {
accountingPeer.lock.Lock()
defer accountingPeer.lock.Unlock()

balance := new(big.Int)
var w bigint.BigInt
zero := big.NewInt(0)

err := a.store.Get(peerBalanceKey(peer), &balance)
if err != nil {
if !errors.Is(err, storage.ErrNotFound) {
return nil, err
}
err := a.store.Get(peerBalanceKey(peer), &w)
if err != nil && !errors.Is(err, storage.ErrNotFound) {
return nil, err
}
balance := w.Int
if balance == nil {
balance = big.NewInt(0)
}

Expand All @@ -891,14 +896,15 @@ func (a *Accounting) peerLatentDebt(peer swarm.Address) (*big.Int, error) {

accountingPeer := a.getAccountingPeer(peer)

balance := new(big.Int)
var wl bigint.BigInt
zero := big.NewInt(0)

err := a.store.Get(peerBalanceKey(peer), &balance)
if err != nil {
if !errors.Is(err, storage.ErrNotFound) {
return nil, err
}
err := a.store.Get(peerBalanceKey(peer), &wl)
if err != nil && !errors.Is(err, storage.ErrNotFound) {
return nil, err
}
balance := wl.Int
if balance == nil {
balance = big.NewInt(0)
}

Expand All @@ -919,16 +925,20 @@ func (a *Accounting) peerLatentDebt(peer swarm.Address) (*big.Int, error) {
// shadowBalance returns the current debt reduced by any potentially debitable amount stored in shadowReservedBalance
// this represents how much less our debt could potentially be seen by the other party if it's ahead with processing credits corresponding to our shadow reserve
func (a *Accounting) shadowBalance(peer swarm.Address, accountingPeer *accountingPeer) (shadowBalance *big.Int, err error) {
balance := new(big.Int)
var ws bigint.BigInt
zero := big.NewInt(0)

err = a.store.Get(peerBalanceKey(peer), &balance)
err = a.store.Get(peerBalanceKey(peer), &ws)
if err != nil {
if errors.Is(err, storage.ErrNotFound) {
return zero, nil
}
return nil, err
}
balance := ws.Int
if balance == nil {
balance = zero
}

if balance.Cmp(zero) >= 0 {
return zero, nil
Expand Down Expand Up @@ -986,7 +996,7 @@ func (a *Accounting) NotifyPaymentSent(peer swarm.Address, amount *big.Int, rece

loggerV2.Debug("registering payment sent", "peer_address", peer, "amount", amount, "new_balance", nextBalance)

err = a.store.Put(peerBalanceKey(peer), nextBalance)
err = a.store.Put(peerBalanceKey(peer), &bigint.BigInt{Int: nextBalance})
if err != nil {
a.logger.Error(err, "notify payment sent; failed to persist balance")
return
Expand Down Expand Up @@ -1043,7 +1053,7 @@ func (a *Accounting) NotifyPaymentReceived(peer swarm.Address, amount *big.Int)

loggerV2.Debug("surplus crediting peer", "peer_address", peer, "amount", amount, "new_balance", increasedSurplus)

err = a.store.Put(peerSurplusBalanceKey(peer), increasedSurplus)
err = a.store.Put(peerSurplusBalanceKey(peer), &bigint.BigInt{Int: increasedSurplus})
if err != nil {
return fmt.Errorf("failed to persist surplus balance: %w", err)
}
Expand All @@ -1064,7 +1074,7 @@ func (a *Accounting) NotifyPaymentReceived(peer swarm.Address, amount *big.Int)

loggerV2.Debug("crediting peer", "peer_address", peer, "amount", amount, "new_balance", nextBalance)

err = a.store.Put(peerBalanceKey(peer), nextBalance)
err = a.store.Put(peerBalanceKey(peer), &bigint.BigInt{Int: nextBalance})
if err != nil {
return fmt.Errorf("failed to persist balance: %w", err)
}
Expand All @@ -1083,7 +1093,7 @@ func (a *Accounting) NotifyPaymentReceived(peer swarm.Address, amount *big.Int)

loggerV2.Debug("surplus crediting peer due to refreshment", "peer_address", peer, "amount", surplusGrowth, "new_balance", increasedSurplus)

err = a.store.Put(peerSurplusBalanceKey(peer), increasedSurplus)
err = a.store.Put(peerSurplusBalanceKey(peer), &bigint.BigInt{Int: increasedSurplus})
if err != nil {
return fmt.Errorf("failed to persist surplus balance: %w", err)
}
Expand Down Expand Up @@ -1165,7 +1175,7 @@ func (a *Accounting) NotifyRefreshmentSent(peer swarm.Address, attemptedAmount,

newBalance := new(big.Int).Add(currentBalance, amount)

err = a.store.Put(peerBalanceKey(peer), newBalance)
err = a.store.Put(peerBalanceKey(peer), &bigint.BigInt{Int: newBalance})
if err != nil {
a.logger.Error(err, "notifyrefreshmentsent failed to persist balance")
return
Expand Down Expand Up @@ -1206,7 +1216,7 @@ func (a *Accounting) NotifyRefreshmentReceived(peer swarm.Address, amount *big.I

// We allow a refreshment to potentially put us into debt as it was previously negotiated and be limited to the peer's outstanding debt plus shadow reserve
loggerV2.Debug("crediting peer", "peer_address", peer, "amount", amount, "new_balance", nextBalance)
err = a.store.Put(peerBalanceKey(peer), nextBalance)
err = a.store.Put(peerBalanceKey(peer), &bigint.BigInt{Int: nextBalance})
if err != nil {
return fmt.Errorf("failed to persist balance: %w", err)
}
Expand Down Expand Up @@ -1269,7 +1279,7 @@ func (a *Accounting) increaseBalance(peer swarm.Address, _ *accountingPeer, pric
if newSurplusBalance.Cmp(big.NewInt(0)) >= 0 {
loggerV2.Debug("surplus debiting peer", "peer_address", peer, "price", price, "new_balance", newSurplusBalance)

err = a.store.Put(peerSurplusBalanceKey(peer), newSurplusBalance)
err = a.store.Put(peerSurplusBalanceKey(peer), &bigint.BigInt{Int: newSurplusBalance})
if err != nil {
return nil, fmt.Errorf("failed to persist surplus balance: %w", err)
}
Expand All @@ -1290,7 +1300,7 @@ func (a *Accounting) increaseBalance(peer swarm.Address, _ *accountingPeer, pric
// let's store 0 as surplus balance
loggerV2.Debug("surplus debiting peer", "peer_address", peer, "amount", debitIncrease, "new_balance", 0)

err = a.store.Put(peerSurplusBalanceKey(peer), big.NewInt(0))
err = a.store.Put(peerSurplusBalanceKey(peer), &bigint.BigInt{Int: big.NewInt(0)})
if err != nil {
return nil, fmt.Errorf("failed to persist surplus balance: %w", err)
}
Expand All @@ -1308,7 +1318,7 @@ func (a *Accounting) increaseBalance(peer swarm.Address, _ *accountingPeer, pric

loggerV2.Debug("debiting peer", "peer_address", peer, "price", price, "new_balance", nextBalance)

err = a.store.Put(peerBalanceKey(peer), nextBalance)
err = a.store.Put(peerBalanceKey(peer), &bigint.BigInt{Int: nextBalance})
if err != nil {
return nil, fmt.Errorf("failed to persist balance: %w", err)
}
Expand Down Expand Up @@ -1445,12 +1455,12 @@ func (a *Accounting) Connect(peer swarm.Address, fullNode bool) {
accountingPeer.thresholdGrowAt.Set(thresholdGrowStep)
accountingPeer.disconnectLimit.Set(disconnectLimit)

err := a.store.Put(peerBalanceKey(peer), zero)
err := a.store.Put(peerBalanceKey(peer), &bigint.BigInt{Int: zero})
if err != nil {
a.logger.Error(err, "failed to persist balance")
}

err = a.store.Put(peerSurplusBalanceKey(peer), zero)
err = a.store.Put(peerSurplusBalanceKey(peer), &bigint.BigInt{Int: zero})
if err != nil {
a.logger.Error(err, "failed to persist surplus balance")
}
Expand All @@ -1475,7 +1485,7 @@ func (a *Accounting) decreaseOriginatedBalanceTo(peer swarm.Address, limit *big.

// If originated balance is more into the negative domain, set it to limit
if originatedBalance.Cmp(toSet) < 0 {
err = a.store.Put(originatedBalanceKey(peer), toSet)
err = a.store.Put(originatedBalanceKey(peer), &bigint.BigInt{Int: toSet})
if err != nil {
return fmt.Errorf("failed to persist originated balance: %w", err)
}
Expand All @@ -1497,7 +1507,7 @@ func (a *Accounting) decreaseOriginatedBalanceBy(peer swarm.Address, amount *big
// Move originated balance into the positive domain by amount
newOriginatedBalance := new(big.Int).Add(originatedBalance, amount)

err = a.store.Put(originatedBalanceKey(peer), newOriginatedBalance)
err = a.store.Put(originatedBalanceKey(peer), &bigint.BigInt{Int: newOriginatedBalance})
if err != nil {
return fmt.Errorf("failed to persist originated balance: %w", err)
}
Expand Down
18 changes: 18 additions & 0 deletions pkg/bigint/bigint.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,3 +42,21 @@ func (i *BigInt) UnmarshalJSON(b []byte) error {
func Wrap(i *big.Int) *BigInt {
return &BigInt{Int: i}
}

// MarshalBinary implements encoding.BinaryMarshaler using Gob encoding.
// Panics if the underlying *big.Int is nil, as this indicates a programmer error.
func (i *BigInt) MarshalBinary() ([]byte, error) {
if i.Int == nil {
Comment thread
martinconic marked this conversation as resolved.
panic("bigint: MarshalBinary called on nil Int")
}
return i.GobEncode()
}

// UnmarshalBinary implements encoding.BinaryUnmarshaler using Gob decoding.
func (i *BigInt) UnmarshalBinary(data []byte) error {
if len(data) == 0 {
return fmt.Errorf("bigint: UnmarshalBinary called with empty data")
}
i.Int = new(big.Int)
return i.GobDecode(data)
}
88 changes: 88 additions & 0 deletions pkg/bigint/bigint_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,94 @@ import (
"github.com/ethersphere/bee/v2/pkg/bigint"
)

func TestBinaryMarshalingRoundTrip(t *testing.T) {
t.Parallel()

tests := []struct {
name string
val *big.Int
}{
{"positive", big.NewInt(123456789)},
{"negative", big.NewInt(-987654321)},
{"zero", big.NewInt(0)},
{"large", new(big.Int).Mul(big.NewInt(math.MaxInt64), big.NewInt(math.MaxInt64))},
}

for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()

original := bigint.Wrap(tc.val)
data, err := original.MarshalBinary()
if err != nil {
t.Fatalf("MarshalBinary: %v", err)
}

var got bigint.BigInt
if err := got.UnmarshalBinary(data); err != nil {
t.Fatalf("UnmarshalBinary: %v", err)
}

if got.Cmp(tc.val) != 0 {
t.Fatalf("got %v, want %v", got.Int, tc.val)
}
})
}
}

// TestBinaryMarshalingGobCompatibility verifies that MarshalBinary produces
// byte-identical output to big.Int.GobEncode, confirming that nodes upgrading
// from the old code (which stored raw *big.Int via GobEncode) will write
// identical bytes after migration.
func TestBinaryMarshalingGobCompatibility(t *testing.T) {
t.Parallel()

val := big.NewInt(555000)
gobData, err := val.GobEncode()
if err != nil {
t.Fatal(err)
}

newData, err := bigint.Wrap(val).MarshalBinary()
if err != nil {
t.Fatal(err)
}

if !reflect.DeepEqual(gobData, newData) {
t.Fatalf("MarshalBinary output differs from GobEncode: got %v, want %v", newData, gobData)
}

var got bigint.BigInt
if err := got.UnmarshalBinary(gobData); err != nil {
t.Fatalf("UnmarshalBinary of gob data: %v", err)
}
if got.Cmp(val) != 0 {
t.Fatalf("got %v, want %v", got.Int, val)
}
}

func TestMarshalBinaryNilPanics(t *testing.T) {
t.Parallel()

defer func() {
if r := recover(); r == nil {
t.Fatal("expected panic on MarshalBinary with nil Int, got none")
}
}()

var w bigint.BigInt // Int is nil
_, _ = w.MarshalBinary()
}

func TestUnmarshalBinaryEmptyErrors(t *testing.T) {
t.Parallel()

var w bigint.BigInt
if err := w.UnmarshalBinary([]byte{}); err == nil {
t.Fatal("expected error on UnmarshalBinary with empty data, got nil")
}
}

func TestMarshaling(t *testing.T) {
t.Parallel()

Expand Down
Loading
Loading