Skip to content
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 33 additions & 22 deletions plugin/evm/vm_warp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -343,22 +343,29 @@ func testWarpVMTransaction(t *testing.T, scheme string, unsignedMessage *avalanc
GetSubnetIDF: func(context.Context, ids.ID) (ids.ID, error) {
return ids.Empty, nil
},
GetValidatorSetF: func(_ context.Context, height uint64, _ ids.ID) (map[ids.NodeID]*validators.GetValidatorOutput, error) {
GetWarpValidatorSetF: func(_ context.Context, height uint64, _ ids.ID) (validators.WarpSet, error) {
if height < minimumValidPChainHeight {
return nil, getValidatorSetTestErr
return validators.WarpSet{}, getValidatorSetTestErr
}
return map[ids.NodeID]*validators.GetValidatorOutput{
nodeID1: {
NodeID: nodeID1,
PublicKey: blsPublicKey1,
Weight: 50,
vdrs := validators.WarpSet{
Validators: []*validators.Warp{
{
PublicKey: blsPublicKey1,
PublicKeyBytes: bls.PublicKeyToUncompressedBytes(blsPublicKey1),
Weight: 50,
NodeIDs: []ids.NodeID{nodeID1},
},
{
PublicKey: blsPublicKey2,
PublicKeyBytes: bls.PublicKeyToUncompressedBytes(blsPublicKey2),
Weight: 50,
NodeIDs: []ids.NodeID{nodeID2},
},
},
nodeID2: {
NodeID: nodeID2,
PublicKey: blsPublicKey2,
Weight: 50,
},
}, nil
TotalWeight: 100,
}
avagoUtils.Sort(vdrs.Validators)
return vdrs, nil
},
}

Expand Down Expand Up @@ -645,24 +652,28 @@ func testReceiveWarpMessage(
}
return vm.ctx.SubnetID, nil
},
GetValidatorSetF: func(_ context.Context, height uint64, subnetID ids.ID) (map[ids.NodeID]*validators.GetValidatorOutput, error) {
GetWarpValidatorSetF: func(_ context.Context, height uint64, subnetID ids.ID) (validators.WarpSet, error) {
if height < minimumValidPChainHeight {
return nil, getValidatorSetTestErr
return validators.WarpSet{}, getValidatorSetTestErr
}
signers := subnetSigners
if subnetID == constants.PrimaryNetworkID {
signers = primarySigners
}

vdrOutput := make(map[ids.NodeID]*validators.GetValidatorOutput)
vdrs := validators.WarpSet{}
for _, s := range signers {
vdrOutput[s.nodeID] = &validators.GetValidatorOutput{
NodeID: s.nodeID,
PublicKey: s.secret.PublicKey(),
Weight: s.weight,
}
pk := s.secret.PublicKey()
vdrs.Validators = append(vdrs.Validators, &validators.Warp{
PublicKey: pk,
PublicKeyBytes: bls.PublicKeyToUncompressedBytes(pk),
Weight: s.weight,
NodeIDs: []ids.NodeID{s.nodeID},
})
vdrs.TotalWeight += s.weight
}
return vdrOutput, nil
avagoUtils.Sort(vdrs.Validators)
return vdrs, nil
},
}

Expand Down
56 changes: 42 additions & 14 deletions precompile/contracts/warp/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"errors"
"fmt"

"github.com/ava-labs/avalanchego/utils/constants"
"github.com/ava-labs/avalanchego/vms/evm/predicate"
"github.com/ava-labs/avalanchego/vms/platformvm/warp"
"github.com/ava-labs/avalanchego/vms/platformvm/warp/payload"
Expand All @@ -16,8 +17,6 @@ import (
"github.com/ava-labs/libevm/log"

"github.com/ava-labs/coreth/precompile/precompileconfig"

warpValidators "github.com/ava-labs/coreth/warp/validators"
)

const (
Expand Down Expand Up @@ -202,24 +201,50 @@ func (c *Config) VerifyPredicate(predicateContext *precompileconfig.PredicateCon
quorumNumerator = c.QuorumNumerator
}

log.Debug("verifying warp message", "warpMsg", warpMsg, "quorumNum", quorumNumerator, "quorumDenom", WarpQuorumDenominator)
log.Debug("verifying warp message",
"warpMsg", warpMsg,
"quorumNum", quorumNumerator,
"quorumDenom", WarpQuorumDenominator,
)

// Wrap validators.State on the chain snow context to special case the Primary Network
state := warpValidators.NewState(
predicateContext.SnowCtx.ValidatorState,
predicateContext.SnowCtx.SubnetID,
sourceSubnetID, err := predicateContext.SnowCtx.ValidatorState.GetSubnetID(
context.TODO(),
warpMsg.SourceChainID,
c.RequirePrimaryNetworkSigners,
)
if err != nil {
log.Debug("failed to retrieve subnetID for chain",
"msgID", warpMsg.ID(),
"chainID", warpMsg.SourceChainID,
"err", err,
)
return fmt.Errorf("%w: %w", errCannotRetrieveValidatorSet, err)
}

// The primary network validator set is never required when verifying
// messages from the P-chain.
//
// For the X-chain and the C-chain, chains can be configured not to require
// the primary network validators to have signed the warp message and to use
// the, likely smaller, local subnet's validator set.
canOptimizePrimaryNetwork := !c.RequirePrimaryNetworkSigners || warpMsg.SourceChainID == constants.PlatformChainID
// If the chain is in the primary network and we don't require verifying
// against the primary network validator set, then we override the source
// subnet ID to the local chain's validator set.
if canOptimizePrimaryNetwork && sourceSubnetID == constants.PrimaryNetworkID {
sourceSubnetID = predicateContext.SnowCtx.SubnetID
}

validatorSet, err := warp.GetCanonicalValidatorSetFromChainID(
context.Background(),
state,
validatorSet, err := predicateContext.SnowCtx.ValidatorState.GetWarpValidatorSet(
context.TODO(),
predicateContext.ProposerVMBlockCtx.PChainHeight,
warpMsg.UnsignedMessage.SourceChainID,
sourceSubnetID,
)
if err != nil {
log.Debug("failed to retrieve canonical validator set", "msgID", warpMsg.ID(), "err", err)
log.Debug("failed to retrieve canonical validator set",
"msgID", warpMsg.ID(),
"subnetID", sourceSubnetID,
"err", err,
)
return fmt.Errorf("%w: %w", errCannotRetrieveValidatorSet, err)
}

Expand All @@ -231,7 +256,10 @@ func (c *Config) VerifyPredicate(predicateContext *precompileconfig.PredicateCon
WarpQuorumDenominator,
)
if err != nil {
log.Debug("failed to verify warp signature", "msgID", warpMsg.ID(), "err", err)
log.Debug("failed to verify warp signature",
"msgID", warpMsg.ID(),
"err", err,
)
return fmt.Errorf("%w: %w", errFailedVerification, err)
}

Expand Down
58 changes: 33 additions & 25 deletions precompile/contracts/warp/predicate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -192,8 +192,7 @@ type validatorRange struct {

// createSnowCtx creates a snow.Context instance with a validator state specified by the given validatorRanges
func createSnowCtx(tb testing.TB, validatorRanges []validatorRange) *snow.Context {
getValidatorsOutput := make(map[ids.NodeID]*validators.GetValidatorOutput)

validatorSet := make(map[ids.NodeID]*validators.GetValidatorOutput)
for _, validatorRange := range validatorRanges {
for i := validatorRange.start; i < validatorRange.end; i++ {
validatorOutput := &validators.GetValidatorOutput{
Expand All @@ -203,20 +202,19 @@ func createSnowCtx(tb testing.TB, validatorRanges []validatorRange) *snow.Contex
if validatorRange.publicKey {
validatorOutput.PublicKey = testVdrs[i].vdr.PublicKey
}
getValidatorsOutput[testVdrs[i].nodeID] = validatorOutput
validatorSet[testVdrs[i].nodeID] = validatorOutput
}
}

snowCtx := snowtest.Context(tb, snowtest.CChainID)
state := &validatorstest.State{
snowCtx.ValidatorState = &validatorstest.State{
GetSubnetIDF: func(context.Context, ids.ID) (ids.ID, error) {
return sourceSubnetID, nil
},
GetValidatorSetF: func(context.Context, uint64, ids.ID) (map[ids.NodeID]*validators.GetValidatorOutput, error) {
return getValidatorsOutput, nil
GetWarpValidatorSetF: func(context.Context, uint64, ids.ID) (validators.WarpSet, error) {
return validators.FlattenValidatorSet(validatorSet)
},
}
snowCtx.ValidatorState = state
return snowCtx
}

Expand Down Expand Up @@ -251,20 +249,29 @@ func testWarpMessageFromPrimaryNetwork(t *testing.T, requirePrimaryNetworkSigner
unsignedMsg, err := avalancheWarp.NewUnsignedMessage(constants.UnitTestID, cChainID, addressedCall.Bytes())
require.NoError(err)

getValidatorsOutput := make(map[ids.NodeID]*validators.GetValidatorOutput)
blsSignatures := make([]*bls.Signature, 0, numKeys)
var (
warpValidators = validators.WarpSet{
Validators: make([]*validators.Warp, 0, numKeys),
TotalWeight: 20 * uint64(numKeys),
}
blsSignatures = make([]*bls.Signature, 0, numKeys)
)
for i := 0; i < numKeys; i++ {
sig, err := testVdrs[i].sk.Sign(unsignedMsg.Bytes())
vdr := testVdrs[i]
sig, err := vdr.sk.Sign(unsignedMsg.Bytes())
require.NoError(err)

validatorOutput := &validators.GetValidatorOutput{
NodeID: testVdrs[i].nodeID,
Weight: 20,
PublicKey: testVdrs[i].vdr.PublicKey,
}
getValidatorsOutput[testVdrs[i].nodeID] = validatorOutput
blsSignatures = append(blsSignatures, sig)

pk := vdr.sk.PublicKey()
warpValidators.Validators = append(warpValidators.Validators, &validators.Warp{
PublicKey: pk,
PublicKeyBytes: bls.PublicKeyToUncompressedBytes(pk),
Weight: 20,
NodeIDs: []ids.NodeID{vdr.nodeID},
})
}
agoUtils.Sort(warpValidators.Validators)

aggregateSignature, err := bls.AggregateSignatures(blsSignatures)
require.NoError(err)
bitSet := set.NewBits()
Expand All @@ -288,13 +295,13 @@ func testWarpMessageFromPrimaryNetwork(t *testing.T, requirePrimaryNetworkSigner
require.Equal(chainID, cChainID)
return constants.PrimaryNetworkID, nil // Return Primary Network SubnetID
},
GetValidatorSetF: func(_ context.Context, _ uint64, subnetID ids.ID) (map[ids.NodeID]*validators.GetValidatorOutput, error) {
GetWarpValidatorSetF: func(_ context.Context, _ uint64, subnetID ids.ID) (validators.WarpSet, error) {
expectedSubnetID := snowCtx.SubnetID
if requirePrimaryNetworkSigners {
expectedSubnetID = constants.PrimaryNetworkID
}
require.Equal(expectedSubnetID, subnetID)
return getValidatorsOutput, nil
return warpValidators, nil
},
}

Expand Down Expand Up @@ -721,25 +728,26 @@ func makeWarpPredicateTests(tb testing.TB) map[string]precompiletest.PredicateTe
testName := fmt.Sprintf("%d validators w/ %d signers/repeated PublicKeys", totalNodes, numSigners)

pred := createPredicate(numSigners)
getValidatorsOutput := make(map[ids.NodeID]*validators.GetValidatorOutput, totalNodes)
validatorSet := make(map[ids.NodeID]*validators.GetValidatorOutput, totalNodes)
for i := 0; i < totalNodes; i++ {
getValidatorsOutput[testVdrs[i].nodeID] = &validators.GetValidatorOutput{
validatorSet[testVdrs[i].nodeID] = &validators.GetValidatorOutput{
NodeID: testVdrs[i].nodeID,
Weight: 20,
PublicKey: testVdrs[i%numSigners].vdr.PublicKey,
}
}
warpValidators, err := validators.FlattenValidatorSet(validatorSet)
require.NoError(tb, err)

snowCtx := snowtest.Context(tb, snowtest.CChainID)
state := &validatorstest.State{
snowCtx.ValidatorState = &validatorstest.State{
GetSubnetIDF: func(context.Context, ids.ID) (ids.ID, error) {
return sourceSubnetID, nil
},
GetValidatorSetF: func(context.Context, uint64, ids.ID) (map[ids.NodeID]*validators.GetValidatorOutput, error) {
return getValidatorsOutput, nil
GetWarpValidatorSetF: func(context.Context, uint64, ids.ID) (validators.WarpSet, error) {
return warpValidators, nil
},
}
snowCtx.ValidatorState = state

predicateTests[testName] = createValidPredicateTest(snowCtx, uint64(numSigners), pred)
}
Expand Down
2 changes: 1 addition & 1 deletion warp/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ func (a *API) aggregateSignatures(ctx context.Context, unsignedMessage *warp.Uns
return nil, err
}

validatorSet, err := warp.GetCanonicalValidatorSetFromSubnetID(ctx, validatorState, pChainHeight, subnetID)
validatorSet, err := validatorState.GetWarpValidatorSet(ctx, pChainHeight, subnetID)
if err != nil {
return nil, fmt.Errorf("failed to get validator set: %w", err)
}
Expand Down
55 changes: 0 additions & 55 deletions warp/validators/state.go

This file was deleted.

Loading
Loading