Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 64 additions & 1 deletion htlcswitch/hop/iterator.go
Original file line number Diff line number Diff line change
Expand Up @@ -690,7 +690,7 @@ func (p *OnionProcessor) ReconstructHopIterator(r io.Reader, rHash []byte,

// Attempt to process the Sphinx packet. We include the payment hash of
// the HTLC as it's authenticated within the Sphinx packet itself as
// associated data in order to thwart attempts a replay attacks. In the
// associated data in order to thwart replay attacks. In the
// case of a replay, an attacker is *forced* to use the same payment
// hash twice, thereby losing their money entirely.
sphinxPacket, err := p.router.ReconstructOnionPacket(
Expand Down Expand Up @@ -737,6 +737,69 @@ func (r *DecodeHopIteratorResponse) Result() (Iterator, lnwire.FailCode) {
return r.HopIterator, r.FailCode
}

// DecodeHopIterator attempts to decode a valid sphinx packet from the passed
// io.Reader instance using the rHash as the associated data when checking the
// relevant MACs during the decoding process.
Comment thread
ziggie1984 marked this conversation as resolved.
func (p *OnionProcessor) DecodeHopIterator(r io.Reader, rHash []byte,
incomingCltv uint32, incomingAmount lnwire.MilliSatoshi,
Comment thread
ziggie1984 marked this conversation as resolved.
blindingPoint lnwire.BlindingPointRecord) (Iterator, lnwire.FailCode) {

onionPkt := &sphinx.OnionPacket{}
if err := onionPkt.Decode(r); err != nil {
switch {
case errors.Is(err, sphinx.ErrInvalidOnionVersion):
return nil, lnwire.CodeInvalidOnionVersion
case errors.Is(err, sphinx.ErrInvalidOnionKey):
return nil, lnwire.CodeInvalidOnionKey
default:
log.Errorf("unable to decode onion packet: %v", err)
return nil, lnwire.CodeInvalidOnionKey
Comment thread
ziggie1984 marked this conversation as resolved.
}
}

// If a blinding point was provided in the update_add_htlc message,
// pass it through so the sphinx router can derive the correct shared
// secret for blinded hops.
var opts []sphinx.ProcessOnionOpt
blindingPoint.WhenSome(func(
b tlv.RecordT[lnwire.BlindingPointTlvType,
*btcec.PublicKey]) {

opts = append(opts, sphinx.WithBlindingPoint(b.Val))
})
Comment thread
ziggie1984 marked this conversation as resolved.

// Attempt to process the Sphinx packet. We include the payment hash of
// the HTLC as it's authenticated within the Sphinx packet itself as
// associated data in order to thwart replay attacks. In the
// case of a replay, an attacker is *forced* to use the same payment
Comment thread
ziggie1984 marked this conversation as resolved.
// hash twice, thereby losing their money entirely.
sphinxPacket, err := p.router.ProcessOnionPacket(
onionPkt, rHash, incomingCltv, opts...,
)
if err != nil {
switch {
case errors.Is(err, sphinx.ErrInvalidOnionVersion):
return nil, lnwire.CodeInvalidOnionVersion
case errors.Is(err, sphinx.ErrInvalidOnionHMAC):
return nil, lnwire.CodeInvalidOnionHmac
case errors.Is(err, sphinx.ErrInvalidOnionKey):
return nil, lnwire.CodeInvalidOnionKey
default:
log.Errorf("unable to process onion packet: %v", err)
return nil, lnwire.CodeInvalidOnionKey
}
Comment thread
ziggie1984 marked this conversation as resolved.
}

return makeSphinxHopIterator(p.router, onionPkt, sphinxPacket,
BlindingKit{
Processor: p.router,
UpdateAddBlinding: blindingPoint,
IncomingAmount: incomingAmount,
IncomingCltv: incomingCltv,
}, rHash,
), lnwire.CodeNone
}

// DecodeHopIterators performs batched decoding and validation of incoming
// sphinx packets. For the same `id`, this method will return the same iterators
// and failcodes upon subsequent invocations.
Expand Down
164 changes: 164 additions & 0 deletions htlcswitch/hop/iterator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,170 @@ func TestSphinxHopIteratorForwardingInstructions(t *testing.T) {
}
}

// TestDecodeHopIterator tests that DecodeHopIterator can successfully process
// a real onion packet constructed by the sphinx library and return a valid hop
// iterator with the correct forwarding information. It also tests various error
// cases such as truncated packets and corrupted HMACs.
func TestDecodeHopIterator(t *testing.T) {
t.Parallel()

// Generate a fresh private key for our onion processor (the
// "receiving" node).
receiverPrivKey, err := btcec.NewPrivateKey()
require.NoError(t, err)

sphinxRouter := sphinx.NewRouter(
&sphinx.PrivKeyECDH{PrivKey: receiverPrivKey},
sphinx.NewNoOpReplayLog(),
)
require.NoError(t, sphinxRouter.Start())
defer sphinxRouter.Stop()

processor := NewOnionProcessor(sphinxRouter)

// Session key used by the "sender" to construct the onion.
sessionKey, err := btcec.NewPrivateKey()
require.NoError(t, err)

// Build a TLV payload for the final hop with amount and CLTV.
var (
fwdAmt uint64 = 500_000
outgoingCltv uint32 = 144
incomingCltv uint32 = 200
incomingAmt lnwire.MilliSatoshi = 600_000
noBlinding lnwire.BlindingPointRecord
)
var payloadBuf bytes.Buffer
Comment on lines +132 to +139
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

put this in one var block

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

merged both var blocks into one ✅

tlvRecords := []tlv.Record{
record.NewAmtToFwdRecord(&fwdAmt),
record.NewLockTimeRecord(&outgoingCltv),
}
tlvStream, err := tlv.NewStream(tlvRecords...)
require.NoError(t, err)
require.NoError(t, tlvStream.Encode(&payloadBuf))

// Build a one-hop payment path to the receiver.
var path sphinx.PaymentPath
path[0] = sphinx.OnionHop{
NodePub: *receiverPrivKey.PubKey(),
HopPayload: sphinx.HopPayload{
Type: sphinx.PayloadTLV,
Payload: payloadBuf.Bytes(),
},
}

// Create the onion packet.
rHash := [32]byte{0xaa, 0xbb, 0xcc}
onionPkt, err := sphinx.NewOnionPacket(
&path, sessionKey, rHash[:],
sphinx.DeterministicPacketFiller,
)
require.NoError(t, err)

// serializeOnion is a helper that encodes an onion packet to bytes.
serializeOnion := func(pkt *sphinx.OnionPacket) []byte {
var buf bytes.Buffer
require.NoError(t, pkt.Encode(&buf))
return buf.Bytes()
}

validOnionBytes := serializeOnion(onionPkt)

tests := []struct {
name string
onionBytes []byte
rHash []byte
expectedFail lnwire.FailCode
checkPayload bool
}{
{
name: "valid onion",
onionBytes: validOnionBytes,
rHash: rHash[:],
expectedFail: lnwire.CodeNone,
checkPayload: true,
},
{
name: "truncated packet",
onionBytes: validOnionBytes[:10],
rHash: rHash[:],
expectedFail: lnwire.CodeInvalidOnionKey,
},
{
name: "empty reader",
onionBytes: []byte{},
rHash: rHash[:],
expectedFail: lnwire.CodeInvalidOnionKey,
},
{
name: "corrupted HMAC",
onionBytes: func() []byte {
corrupted := make([]byte, len(validOnionBytes))
copy(corrupted, validOnionBytes)
// Flip a byte in the HMAC (last 32 bytes of
// the packet).
corrupted[len(corrupted)-1] ^= 0xff

return corrupted
}(),
rHash: rHash[:],
expectedFail: lnwire.CodeInvalidOnionHmac,
},
{
name: "wrong payment hash",
onionBytes: validOnionBytes,
rHash: bytes.Repeat([]byte{0xff}, 32),
expectedFail: lnwire.CodeInvalidOnionHmac,
},
{
name: "invalid version byte",
onionBytes: func() []byte {
corrupted := make([]byte, len(validOnionBytes))
copy(corrupted, validOnionBytes)
// Set an invalid version (first byte).
corrupted[0] = 0xff

return corrupted
}(),
rHash: rHash[:],
expectedFail: lnwire.CodeInvalidOnionVersion,
},
}

for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()

reader := bytes.NewReader(tc.onionBytes)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: the subtests don't call t.Parallel(). The outer test does, and the sphinx router is read-only across all subtests, so parallel execution is safe. lnd style generally includes t.Parallel() as the first line inside each t.Run closure.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks for pointing this out. updated so the subtests now run in parallel.

iterator, failCode := processor.DecodeHopIterator(
reader, tc.rHash, incomingCltv, incomingAmt,
noBlinding,
)
Comment on lines +243 to +244
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The processor.DecodeHopIterator function call exceeds the recommended line length. Please wrap the arguments according to the style guide (Repository Style Guide, lines 170-186).

				reader, tc.rHash, incomingCltv,
				incomingAmt, noBlinding,


require.Equal(t, tc.expectedFail, failCode)

if !tc.checkPayload {
return
}

require.NotNil(t, iterator)

payload, role, err := iterator.HopPayload()
require.NoError(t, err)
require.Equal(t, RouteRoleCleartext, role)

fwdInfo := payload.ForwardingInfo()
require.Equal(
t, lnwire.MilliSatoshi(fwdAmt),
fwdInfo.AmountToForward,
)
require.Equal(
t, outgoingCltv, fwdInfo.OutgoingCTLV,
)
})
}
}

// TestForwardingAmountCalc tests calculation of forwarding amounts from the
// hop's forwarding parameters.
func TestForwardingAmountCalc(t *testing.T) {
Expand Down
15 changes: 15 additions & 0 deletions htlcswitch/hop/log.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,28 @@ package hop

import (
"github.com/btcsuite/btclog/v2"
"github.com/lightningnetwork/lnd/build"
)

// Subsystem defines the logging sub system name of this package.
const Subsystem = "HOPS"

// log is a logger that is initialized with no output filters. This
// means the package will not perform any logging by default until the caller
// requests it.
var log btclog.Logger
Comment on lines 11 to 14
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The comment for the log variable states that it 'will not perform any logging by default until the caller requests it.' However, the newly added init() function now initializes the logger by default. Please update the comment to reflect this change for clarity.

Suggested change
// log is a logger that is initialized with no output filters. This
// means the package will not perform any logging by default until the caller
// requests it.
var log btclog.Logger
// log is a logger that is initialized with no output filters. This
// means the package will perform logging by default, but the caller can
// customize it using UseLogger.
var log btclog.Logger

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The approach is copied from other lnd subpackage log.go files (like htlcswitch/log.go and funding/log.go) so I think it's probably okay to keep.


// The default amount of logging is none.
func init() {
UseLogger(build.NewSubLogger(Subsystem, nil))
}

// DisableLog disables all library log output. Logging output is disabled
// by default until UseLogger is called with an enabled logger.
func DisableLog() {
	UseLogger(btclog.Disabled)
}

// UseLogger uses a specified Logger to output package logging info. This
// function is called from the parent package htlcswitch logger initialization.
func UseLogger(logger btclog.Logger) {
Expand Down
4 changes: 3 additions & 1 deletion htlcswitch/mock.go
Original file line number Diff line number Diff line change
Expand Up @@ -491,7 +491,8 @@ func newMockIteratorDecoder() *mockIteratorDecoder {
}

func (p *mockIteratorDecoder) DecodeHopIterator(r io.Reader, rHash []byte,
cltv uint32) (hop.Iterator, lnwire.FailCode) {
cltv uint32, _ lnwire.MilliSatoshi,
_ lnwire.BlindingPointRecord) (hop.Iterator, lnwire.FailCode) {

var b [4]byte
_, err := r.Read(b[:])
Expand Down Expand Up @@ -540,6 +541,7 @@ func (p *mockIteratorDecoder) DecodeHopIterators(id []byte,
for _, req := range reqs {
iterator, failcode := p.DecodeHopIterator(
req.OnionReader, req.RHash, req.IncomingCltv,
req.IncomingAmount, req.BlindingPoint,
)

if p.decodeFail {
Expand Down
Loading