Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion lnwire/channel_update.go
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ func (a *ChannelUpdate1) Decode(r io.Reader, _ uint32) error {
var inboundFee = a.InboundFee.Zero()
typeMap, err := tlvRecords.ExtractRecords(&inboundFee)
if err != nil {
return err
return fmt.Errorf("%w: %w", ErrParsingExtraTLVBytes, err)
}

val, ok := typeMap[a.InboundFee.TlvType()]
Expand Down
7 changes: 7 additions & 0 deletions lnwire/error.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,13 @@ import (
"io"
)

var (
// ErrParsingExtraTLVBytes is returned when we attempt to parse
// extra opaque bytes as a TLV stream, but the parsing fails due to
// and invalid TLV stream.
ErrParsingExtraTLVBytes = fmt.Errorf("error parsing extra TLV bytes")
)

// FundingError represents a set of errors that can be encountered and sent
// during the funding workflow.
type FundingError uint8
Expand Down
2 changes: 1 addition & 1 deletion lnwire/onion_error.go
Original file line number Diff line number Diff line change
Expand Up @@ -1361,7 +1361,7 @@ func DecodeFailureMessage(r io.Reader, pver uint32) (FailureMessage, error) {
case Serializable:
if err := f.Decode(r, pver); err != nil {
return nil, fmt.Errorf("unable to decode error "+
"update (type=%T): %v", failure, err)
"update (type=%T): %w", failure, err)
}
}

Expand Down
12 changes: 11 additions & 1 deletion payments/db/kv_store.go
Original file line number Diff line number Diff line change
Expand Up @@ -2099,9 +2099,19 @@ func deserializeHTLCFailInfo(r io.Reader) (*HTLCFailInfo, error) {
f.Message, err = lnwire.DecodeFailureMessage(
bytes.NewReader(failureBytes), 0,
)
if err != nil {
if err != nil &&
!errors.Is(err, lnwire.ErrParsingExtraTLVBytes) {

return nil, err
}

// In case we have an invalid TLV stream regarding the extra
// tlv data we still continue with the decoding of the
// HTLCFailInfo.
if errors.Is(err, lnwire.ErrParsingExtraTLVBytes) {
log.Warnf("Failed to decode extra TLV bytes for "+
"failure message: %v", err)
}
}

var reason byte
Expand Down
119 changes: 119 additions & 0 deletions payments/db/kv_store_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,15 @@ package paymentsdb

import (
"bytes"
"encoding/binary"
"io"
"math"
"reflect"
"testing"
"time"

"github.com/btcsuite/btcd/btcec/v2/ecdsa"
"github.com/btcsuite/btcd/wire"
"github.com/btcsuite/btcwallet/walletdb"
"github.com/lightningnetwork/lnd/kvdb"
"github.com/lightningnetwork/lnd/lntypes"
Expand Down Expand Up @@ -1070,3 +1074,118 @@ func TestLazySessionKeyDeserialize(t *testing.T) {
sessionKey := attempt.SessionKey()
require.Equal(t, priv, sessionKey)
}

// TestDeserializeHTLCFailInfoInvalidTLV tests that deserializeHTLCFailInfo
// handles invalid extra tlv data gracefully by not failing.
func TestDeserializeHTLCFailInfoInvalidTLV(t *testing.T) {
// Create a channel update with valid data first, then encode it.
testSig := &ecdsa.Signature{}
sig, _ := lnwire.NewSigFromSignature(testSig)
chanUpdate := &lnwire.ChannelUpdate1{
Signature: sig,
ShortChannelID: lnwire.NewShortChanIDFromInt(1),
Timestamp: 1,
MessageFlags: 0,
ChannelFlags: 1,
ExtraOpaqueData: make([]byte, 0),
}

var chanUpdateBuf bytes.Buffer
err := chanUpdate.Encode(&chanUpdateBuf, 0)
require.NoError(t, err)

// Append invalid inbound fee TLV record to the encoded channel update.
// The inbound fee TLV has type 55555 and should have 8 bytes of data
// (2 uint32 values: BaseFee and FeeRate). We create an invalid one by
// using the correct type but with incomplete data (only 6 bytes
// instead of 8).
var invalidInboundFeeTLV bytes.Buffer

// Write type 55555 as varint: 0xfd + 2 bytes (canonical encoding)
err = invalidInboundFeeTLV.WriteByte(0xfd)
require.NoError(t, err)

var typeBytes [2]byte
binary.BigEndian.PutUint16(typeBytes[:], 55555)
_, err = invalidInboundFeeTLV.Write(typeBytes[:])
require.NoError(t, err)

// Write length as 8 (single byte since 8 < 0xfd, no varint needed)
err = invalidInboundFeeTLV.WriteByte(8)
require.NoError(t, err)

// Write only 6 bytes of value data (incomplete, should be 8 bytes)
var valueBytes [6]byte
binary.BigEndian.PutUint32(valueBytes[0:4], 1)
binary.BigEndian.PutUint16(valueBytes[4:6], 2)
_, err = invalidInboundFeeTLV.Write(valueBytes[:])
require.NoError(t, err)

_, err = chanUpdateBuf.Write(invalidInboundFeeTLV.Bytes())
require.NoError(t, err)

// Manually create a TemporaryChannelFailure failure message with the
// corrupted channel update bytes.
var failureMsgBuf bytes.Buffer

// Write the failure code.
err = lnwire.WriteUint16(
&failureMsgBuf, uint16(lnwire.CodeTemporaryChannelFailure),
)
require.NoError(t, err)

// Write the length of the channel update (including invalid TLV).
err = lnwire.WriteUint16(&failureMsgBuf, uint16(chanUpdateBuf.Len()))
require.NoError(t, err)

// Write the channel update bytes with invalid TLV appended.
_, err = failureMsgBuf.Write(chanUpdateBuf.Bytes())
require.NoError(t, err)

_, err = lnwire.DecodeFailureMessage(&failureMsgBuf, 0)
require.ErrorIs(t, err, lnwire.ErrParsingExtraTLVBytes)
require.ErrorIs(t, err, io.ErrUnexpectedEOF)

// Create an HTLCFailInfo and serialize it with the corrupted failure
// message.
failInfo := &HTLCFailInfo{
FailTime: time.Now(),
Reason: HTLCFailMessage,
FailureSourceIndex: 2,
}

var buf bytes.Buffer

// Manually serialize the HTLCFailInfo with the corrupted failure bytes.
err = serializeTime(&buf, failInfo.FailTime)
require.NoError(t, err)

// Write the failure message bytes.
err = wire.WriteVarBytes(&buf, 0, failureMsgBuf.Bytes())
require.NoError(t, err)

// Write reason and failure source index.
err = WriteElements(
&buf, byte(failInfo.Reason), failInfo.FailureSourceIndex,
)
require.NoError(t, err)

// Now deserialize the HTLCFailInfo - this should NOT fail despite the
// invalid TLV data.
deserializedFailInfo, err := deserializeHTLCFailInfo(&buf)
require.NoError(t, err, "deserializeHTLCFailInfo should not fail "+
"with invalid TLV data")
require.NotNil(t, deserializedFailInfo)

// Verify the basic fields are correctly deserialized.
require.Equal(t, failInfo.Reason, deserializedFailInfo.Reason)
require.Equal(t, failInfo.FailureSourceIndex,
deserializedFailInfo.FailureSourceIndex)

// Verify the failure message is nil because the decoding failed
// due to invalid TLV data. The important part is that the
// HTLCFailInfo deserialization still succeeded despite the invalid
// TLV data in the failure message.
require.Nil(t, deserializedFailInfo.Message,
"Message should be nil when TLV parsing fails")
}
Loading