Skip to content

perf: compute all committees at once at the beginning of an epoch. #1245

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Aug 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -65,3 +65,4 @@ callgrind.out.*

# beacon node oapi json file
beacon-node-oapi.json
flamegraphs/
29 changes: 14 additions & 15 deletions bench/block_processing.exs
Original file line number Diff line number Diff line change
@@ -1,43 +1,42 @@
alias LambdaEthereumConsensus.ForkChoice
alias LambdaEthereumConsensus.ForkChoice.Handlers
alias LambdaEthereumConsensus.StateTransition.Cache
alias LambdaEthereumConsensus.Store
alias LambdaEthereumConsensus.Store.BlockBySlot
alias LambdaEthereumConsensus.Store.BlockDb
alias LambdaEthereumConsensus.Store.StateDb
alias Types.BeaconState
alias Types.BlockInfo
alias Types.SignedBeaconBlock
alias Types.StateInfo
alias Utils.Date

Logger.configure(level: :warning)
Cache.initialize_cache()

# NOTE: this slot must be at the beginning of an epoch (i.e. a multiple of 32)
slot = 9_591_424
slot = 9_649_056

IO.puts("fetching blocks...")
IO.puts("Fetching state and blocks...")
{:ok, %StateInfo{beacon_state: state}} = StateDb.get_state_by_slot(slot)
{:ok, %BlockInfo{signed_block: block}} = BlockDb.get_block_info_by_slot(slot)
{:ok, %BlockInfo{signed_block: new_block} = block_info} = BlockDb.get_block_info_by_slot(slot + 1)
{:ok, %BlockInfo{} = block_info} = BlockDb.get_block_info_by_slot(slot + 1)
{:ok, %BlockInfo{} = block_info_2} = BlockDb.get_block_info_by_slot(slot + 2)

IO.puts("initializing store...")
IO.puts("Initializing store...")
{:ok, store} = Types.Store.get_forkchoice_store(state, block)
store = Handlers.on_tick(store, store.time + 30)

{:ok, root} = BlockBySlot.get(slot)
IO.puts("Processing the block 1...")

IO.puts("about to process block: #{slot + 1}, with root: #{Base.encode16(root)}...")
IO.puts("#{length(attestations)} attestations ; #{length(attester_slashings)} attester slashings")
IO.puts("")
{:ok, new_store} = ForkChoice.process_block(block_info, store)
IO.puts("Processing the block 2...")

if System.get_env("FLAMA") do
Flama.run({ForkChoice, :process_block, [block_info, store]})
filename = "flamegraphs/stacks.#{Date.now_str()}.out"
Flama.run({ForkChoice, :process_block, [block_info_2, new_store]}, output_file: filename)
IO.puts("Flamegraph saved to #{filename}")
else
Benchee.run(
%{
"block (full cache)" => fn ->
ForkChoice.process_block(block_info, store)
ForkChoice.process_block(block_info_2, new_store)
end
},
time: 30
Expand All @@ -46,7 +45,7 @@ else
Benchee.run(
%{
"block (empty cache)" => fn _ ->
ForkChoice.process_block(block_info, store)
ForkChoice.process_block(block_info_2, new_store)
end
},
time: 30,
Expand Down
42 changes: 26 additions & 16 deletions lib/lambda_ethereum_consensus/fork_choice/fork_choice.ex
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ defmodule LambdaEthereumConsensus.ForkChoice do
alias LambdaEthereumConsensus.Libp2pPort
alias LambdaEthereumConsensus.Metrics
alias LambdaEthereumConsensus.P2P.Gossip.OperationsCollector
alias LambdaEthereumConsensus.StateTransition.Accessors
alias LambdaEthereumConsensus.StateTransition.Misc
alias LambdaEthereumConsensus.Store.BlobDb
alias LambdaEthereumConsensus.Store.BlockDb
Expand Down Expand Up @@ -209,13 +210,35 @@ defmodule LambdaEthereumConsensus.ForkChoice do
attestations = signed_block.message.body.attestations
attester_slashings = signed_block.message.body.attester_slashings

# Prefetch relevant states.
states =
Metrics.span_operation(:prefetch_states, nil, nil, fn ->
attestations
|> Enum.map(& &1.data.target)
|> Enum.uniq()
|> Enum.flat_map(&fetch_checkpoint_state/1)
|> Map.new()
end)

# Prefetch committees for all relevant epochs.
Metrics.span_operation(:prefetch_committees, nil, nil, fn ->
Enum.each(states, fn {ch, state} -> Accessors.maybe_prefetch_committees(state, ch.epoch) end)
end)

with {:ok, new_store} <- apply_on_block(store, block_info),
{:ok, new_store} <- process_attestations(new_store, attestations),
{:ok, new_store} <- process_attestations(new_store, attestations, states),
{:ok, new_store} <- process_attester_slashings(new_store, attester_slashings) do
{:ok, new_store}
end
end

def fetch_checkpoint_state(checkpoint) do
case CheckpointStates.get_checkpoint_state(checkpoint) do
{:ok, state} -> [{checkpoint, state}]
_other -> []
end
end

defp apply_on_block(store, block_info) do
Metrics.span_operation(:on_block, nil, nil, fn -> Handlers.on_block(store, block_info) end)
end
Expand All @@ -226,29 +249,16 @@ defmodule LambdaEthereumConsensus.ForkChoice do
end)
end

defp process_attestations(store, attestations) do
defp process_attestations(store, attestations, states) do
Metrics.span_operation(:attestations, nil, nil, fn ->
apply_handler(
attestations,
store,
&Handlers.on_attestation(&1, &2, true, prefetch_states(attestations))
&Handlers.on_attestation(&1, &2, true, states)
)
end)
end

defp prefetch_states(attestations) do
attestations
|> Enum.map(& &1.data.target)
|> Enum.uniq()
|> Enum.flat_map(fn ch ->
case CheckpointStates.get_checkpoint_state(ch) do
{:ok, state} -> [{ch, state}]
_other -> []
end
end)
|> Map.new()
end

@spec recompute_head(Store.t()) :: :ok
def recompute_head(store) do
{:ok, head_root} = Head.get_head(store)
Expand Down
58 changes: 56 additions & 2 deletions lib/lambda_ethereum_consensus/state_transition/accessors.ex
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ defmodule LambdaEthereumConsensus.StateTransition.Accessors do
Functions accessing the current `BeaconState`
"""

require Logger
alias LambdaEthereumConsensus.StateTransition.Cache
alias LambdaEthereumConsensus.StateTransition.Math
alias LambdaEthereumConsensus.StateTransition.Misc
Expand Down Expand Up @@ -281,7 +282,12 @@ defmodule LambdaEthereumConsensus.StateTransition.Accessors do
end

@doc """
Return the number of committees in each slot for the given ``epoch``.
Returns the number of committees in each slot for the given ``epoch``.

The amount of committees is (using integer division):
active_validator_count / slots_per_epoch / TARGET_COMMITTEE_SIZE

The amount of committees will be capped between 1 and MAX_COMMITTEES_PER_SLOT.
"""
@spec get_committee_count_per_slot(BeaconState.t(), Types.epoch()) :: Types.uint64()
def get_committee_count_per_slot(%BeaconState{} = state, epoch) do
Expand All @@ -300,7 +306,16 @@ defmodule LambdaEthereumConsensus.StateTransition.Accessors do
end

@doc """
Return the beacon committee at ``slot`` for ``index``.
Returns the beacon committee at ``slot`` for ``index``.
- slot is the one for which the committee is being calculated. Typically the slot of an
attestation. Might be different from the state slot.
- index: the index of the committee within the slot. It's not the committee index, which is the
index of the committee within the epoch. This transformation is done internally.

The beacon committee returned is a list of global validator indices that should participate in
the requested slot. The order in which the indices are sorted is the same as the one used in
aggregation bits, so checking if the nth member of a committee participated is as simple as
checking if the nth bit is set.
"""
@spec get_beacon_committee(BeaconState.t(), Types.slot(), Types.commitee_index()) ::
{:ok, [Types.validator_index()]} | {:error, String.t()}
Expand All @@ -327,6 +342,41 @@ defmodule LambdaEthereumConsensus.StateTransition.Accessors do
end
end

@doc """
Computes all committees for a single epoch and saves them in the cache. This only happens if the
value is not calculated and if the root for the epoch is available. If any of those conditions
is not true, this function is a noop.

Arguments:
- state: state used to get active validators, seed and others. Any state that is within the same
epoch is equivalent, as validators are updated in epoch boundaries.
- epoch: epoch for which the committees are calculated.
"""
def maybe_prefetch_committees(state, epoch) do
first_slot = Misc.compute_start_slot_at_epoch(epoch)

with {:ok, root} <- get_epoch_root(state, epoch),
false <- Cache.present?(:beacon_committee, {first_slot, {0, root}}) do
Logger.info("[Block processing] Computing committees for epoch #{epoch}")

committees_per_slot = get_committee_count_per_slot(state, epoch)

Misc.compute_all_committees(state, epoch)
|> Enum.with_index()
|> Enum.each(fn {committee, i} ->
# The how do we know for which slot is a committee
slot = first_slot + div(i, committees_per_slot)
index = rem(i, committees_per_slot)

Cache.set(
:beacon_committee,
{slot, {index, root}},
{:ok, committee |> Aja.Enum.to_list()}
)
end)
end
end

@spec get_base_reward_per_increment(BeaconState.t()) :: Types.gwei()
def get_base_reward_per_increment(state) do
numerator = ChainSpec.get("EFFECTIVE_BALANCE_INCREMENT") * ChainSpec.get("BASE_REWARD_FACTOR")
Expand Down Expand Up @@ -505,6 +555,10 @@ defmodule LambdaEthereumConsensus.StateTransition.Accessors do

@doc """
Return the set of attesting indices corresponding to ``data`` and ``bits``.

It computes the committee for the attestation (indices of validators that should participate in
that slot) and then filters the ones that actually participated. It returns an unordered MapSet,
which is useful for checking inclusion, but should be ordered if used to validate an attestation.
"""
@spec get_attesting_indices(BeaconState.t(), Types.AttestationData.t(), Types.bitlist()) ::
{:ok, MapSet.t()} | {:error, String.t()}
Expand Down
3 changes: 3 additions & 0 deletions lib/lambda_ethereum_consensus/state_transition/cache.ex
Original file line number Diff line number Diff line change
Expand Up @@ -73,4 +73,7 @@ defmodule LambdaEthereumConsensus.StateTransition.Cache do
match_spec = generate_cleanup_spec(table, key)
:ets.select_delete(table, match_spec)
end

def present?(table, key), do: :ets.member(table, key)
def set(table, key, value), do: :ets.insert_new(table, {key, value})
end
74 changes: 72 additions & 2 deletions lib/lambda_ethereum_consensus/state_transition/misc.ex
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,11 @@ defmodule LambdaEthereumConsensus.StateTransition.Misc do

import Bitwise
require Aja
require Logger

alias LambdaEthereumConsensus.StateTransition.Accessors
alias LambdaEthereumConsensus.StateTransition.Shuffling
alias LambdaEthereumConsensus.Utils
alias Types.BeaconState

@max_random_byte 2 ** 8 - 1
Expand Down Expand Up @@ -180,9 +184,60 @@ defmodule LambdaEthereumConsensus.StateTransition.Misc do
<<value::unsigned-integer-little-size(64)>>
end

@doc """
Gets all committees for a single epoch. More efficient than calculating each one, as the shuffling
is done a single time for the whole index list and shared values are reused between committees.
"""
@spec compute_all_committees(BeaconState.t(), Types.epoch()) :: list(Aja.Vector.t())
def compute_all_committees(state, epoch) do
indices = Accessors.get_active_validator_indices(state, epoch)
index_count = Aja.Vector.size(indices)
seed = Accessors.get_seed(state, epoch, Constants.domain_beacon_attester())

shuffled_indices = Shuffling.shuffle_list(indices, seed) |> Aja.Vector.to_list()

committee_count =
Accessors.get_committee_count_per_slot(state, epoch) * ChainSpec.get("SLOTS_PER_EPOCH")

committee_sizes =
Enum.map(0..(committee_count - 1), fn index ->
{c_start, c_end} = committee_boundaries(index, index_count, committee_count)
c_end - c_start + 1
end)

# separate using sizes.
Utils.chunk_by_sizes(shuffled_indices, committee_sizes)
end

@doc """
Computes the validator indices of the ``committee_index``-th committee at some epoch
with ``committee_count`` committees, and for some given ``indices`` and ``seed``.

Args:
- indices: a full list of all active validator indices for a single epoch.
- seed: for shuffling calculations.
- committee_index: global number representing the order of the requested committee within the
whole epoch.
- committee_count: total amount of committees for the epoch. Useful to determine the start and end
of the requested committee.

Returns:
- The list of indices for the validators that conform the requested committee. The order is the
same as used in the aggregation bits of an attestation in that committee.

PERFORMANCE NOTE:

Instead of shuffling the full index list, it focuses on the positions of the requested committee
and calculates their shuffled index. Because of the symmetric nature of the shuffling algorithm,
looking at the shuffled index position in the index list gives the element that would end up in
the committee if the full list was to be shuffled.

This is, in logic, equivalent to shuffling the whole validator index list and getting the
elements for the committee under calculation, but only calculating the shuffling for the elements
of the committee.

While the amount of calculations is smaller than the full shuffling, calling this for every
committee in an epoch is inefficient. For that end, compute_all_committees should be called.
"""
@spec compute_committee(Aja.Vector.t(), Types.bytes32(), Types.uint64(), Types.uint64()) ::
{:error, String.t()}
Expand All @@ -197,8 +252,9 @@ defmodule LambdaEthereumConsensus.StateTransition.Misc do
def compute_committee(indices, seed, committee_index, committee_count)
when committee_index < committee_count do
index_count = Aja.Vector.size(indices)
committee_start = div(index_count * committee_index, committee_count)
committee_end = div(index_count * (committee_index + 1), committee_count) - 1

{committee_start, committee_end} =
committee_boundaries(committee_index, index_count, committee_count)

committee_start..committee_end//1
# NOTE: this cannot fail because committee_end < index_count
Expand All @@ -211,6 +267,20 @@ defmodule LambdaEthereumConsensus.StateTransition.Misc do

def compute_committee(_, _, _, _), do: {:error, "Invalid committee index"}

@doc """
Computes the boundaries of a committee.

Args:
- committee_index: epoch based committee index.
- index_count: amount of active validators participating in the epoch.
- committee_count: amount of committees that will be formed in the epoch.
"""
def committee_boundaries(committee_index, index_count, committee_count) do
committee_start = div(index_count * committee_index, committee_count)
committee_end = div(index_count * (committee_index + 1), committee_count) - 1
{committee_start, committee_end}
end

@doc """
Return the 32-byte fork data root for the ``current_version`` and ``genesis_validators_root``.
This is used primarily in signature domains to avoid collisions across forks/chains.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -847,10 +847,14 @@ defmodule LambdaEthereumConsensus.StateTransition.Operations do
end

defp check_matching_aggregation_bits_length(attestation, beacon_committee) do
if BitList.length(attestation.aggregation_bits) == length(beacon_committee) do
aggregation_bits_length = BitList.length(attestation.aggregation_bits)
beacon_committee_length = length(beacon_committee)

if aggregation_bits_length == beacon_committee_length do
:ok
else
{:error, "Mismatched aggregation bits length"}
{:error,
"Mismatched length. aggregation_bits: #{aggregation_bits_length}. beacon_committee: #{beacon_committee_length}"}
end
end

Expand Down
Loading
Loading