Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 16 additions & 73 deletions halo/attest/keeper/cpayload.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import (

"github.com/omni-network/omni/halo/attest/types"
"github.com/omni-network/omni/lib/errors"
"github.com/omni-network/omni/lib/xchain"
"github.com/omni-network/omni/lib/log"
evmenginetypes "github.com/omni-network/omni/octane/evmengine/types"

abci "github.com/cometbft/cometbft/abci/types"
Expand All @@ -26,87 +26,30 @@ var _ evmenginetypes.VoteExtensionProvider = (*Keeper)(nil)
// PrepareVotes returns the cosmosSDK transaction MsgAddVotes that will include all the validator votes included
// in the previous block's vote extensions into the attest module.
//
// Note that the commit is trusted to be valid and only contains valid VEs from the previous block as
// provided by a trusted cometBFT.
func (k *Keeper) PrepareVotes(ctx context.Context, commit abci.ExtendedCommitInfo) ([]sdk.Msg, error) {
// Note that the commit is assumed to be valid and only contains valid VEs from the previous block as
// provided by a trusted cometBFT. Some votes (contained inside VE) may however be invalid, they are discarded.
func (k *Keeper) PrepareVotes(ctx context.Context, commit abci.ExtendedCommitInfo, commitHeight uint64) ([]sdk.Msg, error) {
sdkCtx := sdk.UnwrapSDKContext(ctx)
if err := baseapp.ValidateVoteExtensions(sdkCtx, k.skeeper, sdkCtx.BlockHeight(), sdkCtx.ChainID(), commit); err != nil {
// The VEs in LastLocalCommit is expected to be valid
if err := baseapp.ValidateVoteExtensions(sdkCtx, k.skeeper, 0, "", commit); err != nil {
return nil, errors.Wrap(err, "validate extensions [BUG]")
}

// Adapt portal registry to the supportedChainFunc signature.
supportedChainFunc := func(ctx context.Context, chainVersion xchain.ChainVersion) (bool, error) {
chainVersions, err := k.portalRegistry.ConfLevels(ctx)
if err != nil {
return false, err
}

for _, confLevel := range chainVersions[chainVersion.ID] {
if confLevel == chainVersion.ConfLevel {
return true, nil
}
}

return false, nil
}

msg, err := votesFromLastCommit(
ctx,
k.windowCompare,
supportedChainFunc,
commit,
)
if err != nil {
return nil, err
}

return []sdk.Msg{msg}, nil
}

type windowCompareFunc func(context.Context, xchain.ChainVersion, uint64) (int, error)
type supportedChainFunc func(context.Context, xchain.ChainVersion) (bool, error)

// votesFromLastCommit returns the aggregated votes contained in vote extensions
// of the last local commit.
func votesFromLastCommit(
ctx context.Context,
windowCompare windowCompareFunc,
supportedChain supportedChainFunc,
info abci.ExtendedCommitInfo,

) (*types.MsgAddVotes, error) {
// Verify and discard invalid votes.
// Votes inside the VEs are NOT guaranteed to be valid, since
// VerifyVoteExtension isn't called after quorum is reached.
var allVotes []*types.Vote
for _, vote := range info.Votes {
for _, vote := range commit.Votes {
if vote.BlockIdFlag != cmtproto.BlockIDFlagCommit {
continue // Skip non block votes
continue // Skip non-committed votes
}
votes, ok, err := votesFromExtension(vote.VoteExtension)

selected, _, err := k.parseAndVerifyVoteExtension(sdkCtx, vote.Validator.Address, vote.VoteExtension, commitHeight) //nolint:contextcheck // sdkCtx passed
if err != nil {
return nil, err
} else if !ok {
log.Warn(ctx, "Discarding invalid vote extension", err, log.Hex7("validator", vote.Validator.Address))
continue
}

var selected []*types.Vote
for _, v := range votes.Votes {
if ok, err := supportedChain(ctx, v.AttestHeader.XChainVersion()); err != nil {
return nil, err
} else if !ok {
// Skip votes for unsupported chains.
continue
}

cmp, err := windowCompare(ctx, v.AttestHeader.XChainVersion(), v.AttestHeader.AttestOffset)
if err != nil {
return nil, err
} else if cmp != 0 {
// Skip votes that are not in the current window anymore.
continue
}

selected = append(selected, v)
}

allVotes = append(allVotes, selected...)
}

Expand All @@ -115,10 +58,10 @@ func votesFromLastCommit(
return nil, err
}

return &types.MsgAddVotes{
return []sdk.Msg{&types.MsgAddVotes{
Authority: authtypes.NewModuleAddress(types.ModuleName).String(),
Votes: votes,
}, nil
}}, nil
}

// aggregateVotes aggregates the provided attestations by block header.
Expand Down
77 changes: 18 additions & 59 deletions halo/attest/keeper/cpayload_internal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,16 @@
package keeper

import (
"context"
"testing"

"github.com/omni-network/omni/halo/attest/types"
"github.com/omni-network/omni/lib/k1util"
"github.com/omni-network/omni/lib/xchain"

abci "github.com/cometbft/cometbft/abci/types"
k1 "github.com/cometbft/cometbft/crypto/secp256k1"
types1 "github.com/cometbft/cometbft/proto/tendermint/types"

"github.com/ethereum/go-ethereum/common"

"github.com/cosmos/gogoproto/proto"
fuzz "github.com/google/gofuzz"
"github.com/stretchr/testify/require"
)
Expand Down Expand Up @@ -66,33 +62,25 @@ func TestVotesFromCommitNonUnique(t *testing.T) {
require.Len(t, aggs, 3)
}

func TestVotesFromCommit(t *testing.T) {
func TestAggregateVotes(t *testing.T) {
t.Parallel()
fuzzer := fuzz.New().NilChance(0)

var blockHash common.Hash
fuzzer.Fuzz(&blockHash)

// Generate attestations for following matrix: chains, vals, offset batches
const skipVal = 2 // Skip this validator
const skipChain = 300 // Skip this chain (out of window)
chains := []uint64{100, 200, 300}
chains := []uint64{100, 200}
vals := []k1.PrivKey{k1.GenPrivKey(), k1.GenPrivKey(), k1.GenPrivKey()}
batches := [][]uint64{{1, 2}, {3}, { /*empty*/ }}

expected := make(map[[32]byte]map[xchain.SigTuple]bool)
total := 2 * 3 // 2 chains * 3 heights

var evotes []abci.ExtendedVoteInfo
var allVotes []*types.Vote
for _, chain := range chains {
for i, val := range vals {
flag := types1.BlockIDFlagCommit
if i == skipVal {
flag = types1.BlockIDFlagAbsent
}

for _, val := range vals {
for _, batch := range batches {
var votes []*types.Vote
for _, offset := range batch {
addr, err := k1util.PubKeyToAddress(val.PubKey())
require.NoError(t, err)
Expand All @@ -118,58 +106,29 @@ func TestVotesFromCommit(t *testing.T) {
},
}

if i != skipVal && chain != skipChain {
sig := xchain.SigTuple{
ValidatorAddress: addr,
Signature: sig,
}
attRoot, err := vote.AttestationRoot()
require.NoError(t, err)

if _, ok := expected[attRoot]; !ok {
expected[attRoot] = make(map[xchain.SigTuple]bool)
}
expected[attRoot][sig] = true
}
votes = append(votes, vote)
}
attRoot, err := vote.AttestationRoot()
require.NoError(t, err)

bz, err := proto.Marshal(&types.Votes{
Votes: votes,
})
require.NoError(t, err)
if _, ok := expected[attRoot]; !ok {
expected[attRoot] = make(map[xchain.SigTuple]bool)
}
expected[attRoot][xchain.SigTuple{
ValidatorAddress: addr,
Signature: sig,
}] = true

evotes = append(evotes, abci.ExtendedVoteInfo{
VoteExtension: bz,
BlockIdFlag: flag,
})
allVotes = append(allVotes, vote)
}
}
}
}

info := abci.ExtendedCommitInfo{
Round: 99,
Votes: evotes,
}

comparer := func(ctx context.Context, chainVer xchain.ChainVersion, height uint64) (int, error) {
if chainVer.ID == skipChain {
return 1, nil
}

return 0, nil
}

supported := func(context.Context, xchain.ChainVersion) (bool, error) {
return true, nil
}

resp, err := votesFromLastCommit(context.Background(), comparer, supported, info)
aggs, err := aggregateVotes(allVotes)
require.NoError(t, err)

require.Len(t, resp.Votes, total)
require.Len(t, aggs, total)

for _, agg := range resp.Votes {
for _, agg := range aggs {
attRoot, err := agg.AttestationRoot()
require.NoError(t, err)
for _, s := range agg.Signatures {
Expand Down
79 changes: 50 additions & 29 deletions halo/attest/keeper/keeper.go
Original file line number Diff line number Diff line change
Expand Up @@ -732,77 +732,98 @@ func (k *Keeper) VerifyVoteExtension(ctx sdk.Context, req *abci.RequestVerifyVot
Status: abci.ResponseVerifyVoteExtension_REJECT,
}

_, ok, err := k.parseAndVerifyVoteExtension(ctx, req.ValidatorAddress, req.VoteExtension, uint64(req.Height))
if err != nil {
log.Warn(ctx, "Rejecting vote extension", err, log.Hex7("validator", req.ValidatorAddress))
return respAccept, nil
} else if !ok {
log.Warn(ctx, "Rejecting vote extension containing vote behind window", nil, log.Hex7("validator", req.ValidatorAddress))
return respReject, nil
}

return respAccept, nil
}

// parseAndVerifyVoteExtension returns a list of valid vote extensions and true if all votes are valid,
// or an error if any validation failed (except "vote-behind-window").
//
// vote-behind-window:
// - is not valid,
// - but is not always considered an error (as it is expected in PrepareProposal).
// - false is returned in this case,
// - to indicate that not all votes are valid (nor returned).
func (k *Keeper) parseAndVerifyVoteExtension(
ctx sdk.Context,
valAddr []byte,
voteExt []byte,
voteHeight uint64,
) ([]*types.Vote, bool, error) {
cChainID, err := netconf.ConsensusChainIDStr2Uint64(ctx.ChainID())
if err != nil {
return nil, errors.Wrap(err, "parse chain id")
return nil, false, errors.Wrap(err, "parse chain id")
}

// Get the ethereum address of the validator
ethAddr, err := k.getValEthAddr(ctx, req.ValidatorAddress)
ethAddr, err := k.getValEthAddr(ctx, valAddr, voteHeight)
if err != nil {
return nil, err // This error should never occur
return nil, false, err // This error should never occur
}

// Adding logging attributes to sdk context is a bit tricky
ctx = ctx.WithContext(log.WithCtx(ctx, log.Hex7("validator", req.ValidatorAddress)))

votes, ok, err := votesFromExtension(req.VoteExtension)
votes, ok, err := votesFromExtension(voteExt)
if err != nil {
log.Warn(ctx, "Rejecting invalid vote extension", err)
return respReject, nil
return nil, false, errors.Wrap(err, "votes from extension")
} else if !ok {
return respAccept, nil
return nil, true, nil // Empty vote extension is fine
} else if umath.Len(votes.Votes) > k.voteExtLimit {
log.Warn(ctx, "Rejecting vote extension exceeding limit", nil, "count", len(votes.Votes), "limit", k.voteExtLimit)
return respReject, nil
return nil, false, errors.New("vote extension limit exceeded", "count", len(votes.Votes), "limit", k.voteExtLimit)
}

duplicate := make(map[common.Hash]bool) // Detect identical duplicate votes (same AttestationRoot)
doubleSign := make(map[xchain.AttestHeader]bool) // Detect double sign votes (same AttestHeader)
var resp []*types.Vote
for _, vote := range votes.Votes {
if err := vote.Verify(); err != nil {
log.Warn(ctx, "Rejecting invalid vote", err)
return respReject, nil
return nil, false, errors.Wrap(err, "verify vote")
}

attRoot, err := vote.AttestationRoot()
if err != nil {
return nil, errors.Wrap(err, "att root [BUG]") // Should error in Verify
return nil, false, errors.Wrap(err, "att root")
}
if duplicate[attRoot] {
log.Warn(ctx, "Rejecting duplicate identical vote", nil)
return respReject, nil
return nil, false, errors.New("duplicate identical vote")
}
duplicate[attRoot] = true

if doubleSign[vote.AttestHeader.ToXChain()] {
doubleSignCounter.WithLabelValues(ethAddr.Hex()).Inc()
log.Warn(ctx, "Rejecting duplicate slashable vote", err)

return respReject, nil
return nil, false, errors.New("duplicate slashable vote")
}
doubleSign[vote.AttestHeader.ToXChain()] = true

// Ensure the votes are from the requesting validator itself.
if !bytes.Equal(vote.Signature.ValidatorAddress, ethAddr[:]) {
log.Warn(ctx, "Rejecting mismatching vote and req validator address", nil, "vote", ethAddr, "req", req.ValidatorAddress)
return respReject, nil
return nil, false, errors.New("mismatching vote and req validator address", "vote", ethAddr, "req", vote.Signature.ValidatorAddress)
}

if err := verifyHeaderChains(ctx, cChainID, k.portalRegistry, vote.AttestHeader, vote.BlockHeader); err != nil {
log.Warn(ctx, "Rejecting vote for invalid header chains", err, "chain", k.namer(vote.AttestHeader.XChainVersion()))
return respReject, nil
return nil, false, errors.Wrap(err, "verify chain headers", "chain", k.namer(vote.AttestHeader.XChainVersion()))
}

if cmp, err := k.windowCompare(ctx, vote.AttestHeader.XChainVersion(), vote.AttestHeader.AttestOffset); err != nil {
return nil, errors.Wrap(err, "windower")
} else if cmp != 0 {
log.Warn(ctx, "Rejecting out-of-window vote", nil, "cmp", cmp)
return respReject, nil
return nil, false, errors.Wrap(err, "window compare")
} else if cmp > 0 {
return nil, false, errors.New("vote ahead of window")
} else if cmp < 0 {
// Vote-behind-window is expected in PrepareProposal, just don't add to response.
continue
}

resp = append(resp, vote)
}

return respAccept, nil
return resp, len(votes.Votes) == len(resp), nil
}

type ValSet struct {
Expand Down Expand Up @@ -875,7 +896,7 @@ func (k *Keeper) verifyAggVotes(
cChainID uint64,
valset ValSet,
aggs []*types.AggVote,
windowCompareFunc windowCompareFunc, // Aliased for testing
windowCompareFunc func(context.Context, xchain.ChainVersion, uint64) (int, error), // Aliased for testing
) error {
duplicate := make(map[common.Hash]bool) // Detects duplicate aggregate votes.
countsPerVal := make(map[common.Address]uint64) // Enforce vote extension limit.
Expand Down
Loading
Loading