Skip to content

Commit d643d2d

Browse files
authored
perf: compute all committees at once at the beginning of an epoch. (#1245)
1 parent 4c814a3 commit d643d2d

File tree

15 files changed

+297
-49
lines changed

15 files changed

+297
-49
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,3 +65,4 @@ callgrind.out.*
6565

6666
# beacon node oapi json file
6767
beacon-node-oapi.json
68+
flamegraphs/

bench/block_processing.exs

Lines changed: 14 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,43 +1,42 @@
11
alias LambdaEthereumConsensus.ForkChoice
22
alias LambdaEthereumConsensus.ForkChoice.Handlers
33
alias LambdaEthereumConsensus.StateTransition.Cache
4-
alias LambdaEthereumConsensus.Store
5-
alias LambdaEthereumConsensus.Store.BlockBySlot
64
alias LambdaEthereumConsensus.Store.BlockDb
75
alias LambdaEthereumConsensus.Store.StateDb
8-
alias Types.BeaconState
96
alias Types.BlockInfo
10-
alias Types.SignedBeaconBlock
117
alias Types.StateInfo
8+
alias Utils.Date
129

1310
Logger.configure(level: :warning)
1411
Cache.initialize_cache()
1512

1613
# NOTE: this slot must be at the beginning of an epoch (i.e. a multiple of 32)
17-
slot = 9_591_424
14+
slot = 9_649_056
1815

19-
IO.puts("fetching blocks...")
16+
IO.puts("Fetching state and blocks...")
2017
{:ok, %StateInfo{beacon_state: state}} = StateDb.get_state_by_slot(slot)
2118
{:ok, %BlockInfo{signed_block: block}} = BlockDb.get_block_info_by_slot(slot)
22-
{:ok, %BlockInfo{signed_block: new_block} = block_info} = BlockDb.get_block_info_by_slot(slot + 1)
19+
{:ok, %BlockInfo{} = block_info} = BlockDb.get_block_info_by_slot(slot + 1)
20+
{:ok, %BlockInfo{} = block_info_2} = BlockDb.get_block_info_by_slot(slot + 2)
2321

24-
IO.puts("initializing store...")
22+
IO.puts("Initializing store...")
2523
{:ok, store} = Types.Store.get_forkchoice_store(state, block)
2624
store = Handlers.on_tick(store, store.time + 30)
2725

28-
{:ok, root} = BlockBySlot.get(slot)
26+
IO.puts("Processing the block 1...")
2927

30-
IO.puts("about to process block: #{slot + 1}, with root: #{Base.encode16(root)}...")
31-
IO.puts("#{length(attestations)} attestations ; #{length(attester_slashings)} attester slashings")
32-
IO.puts("")
28+
{:ok, new_store} = ForkChoice.process_block(block_info, store)
29+
IO.puts("Processing the block 2...")
3330

3431
if System.get_env("FLAMA") do
35-
Flama.run({ForkChoice, :process_block, [block_info, store]})
32+
filename = "flamegraphs/stacks.#{Date.now_str()}.out"
33+
Flama.run({ForkChoice, :process_block, [block_info_2, new_store]}, output_file: filename)
34+
IO.puts("Flamegraph saved to #{filename}")
3635
else
3736
Benchee.run(
3837
%{
3938
"block (full cache)" => fn ->
40-
ForkChoice.process_block(block_info, store)
39+
ForkChoice.process_block(block_info_2, new_store)
4140
end
4241
},
4342
time: 30
@@ -46,7 +45,7 @@ else
4645
Benchee.run(
4746
%{
4847
"block (empty cache)" => fn _ ->
49-
ForkChoice.process_block(block_info, store)
48+
ForkChoice.process_block(block_info_2, new_store)
5049
end
5150
},
5251
time: 30,

lib/lambda_ethereum_consensus/fork_choice/fork_choice.ex

Lines changed: 26 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ defmodule LambdaEthereumConsensus.ForkChoice do
1010
alias LambdaEthereumConsensus.Libp2pPort
1111
alias LambdaEthereumConsensus.Metrics
1212
alias LambdaEthereumConsensus.P2P.Gossip.OperationsCollector
13+
alias LambdaEthereumConsensus.StateTransition.Accessors
1314
alias LambdaEthereumConsensus.StateTransition.Misc
1415
alias LambdaEthereumConsensus.Store.BlobDb
1516
alias LambdaEthereumConsensus.Store.BlockDb
@@ -209,13 +210,35 @@ defmodule LambdaEthereumConsensus.ForkChoice do
209210
attestations = signed_block.message.body.attestations
210211
attester_slashings = signed_block.message.body.attester_slashings
211212

213+
# Prefetch relevant states.
214+
states =
215+
Metrics.span_operation(:prefetch_states, nil, nil, fn ->
216+
attestations
217+
|> Enum.map(& &1.data.target)
218+
|> Enum.uniq()
219+
|> Enum.flat_map(&fetch_checkpoint_state/1)
220+
|> Map.new()
221+
end)
222+
223+
# Prefetch committees for all relevant epochs.
224+
Metrics.span_operation(:prefetch_committees, nil, nil, fn ->
225+
Enum.each(states, fn {ch, state} -> Accessors.maybe_prefetch_committees(state, ch.epoch) end)
226+
end)
227+
212228
with {:ok, new_store} <- apply_on_block(store, block_info),
213-
{:ok, new_store} <- process_attestations(new_store, attestations),
229+
{:ok, new_store} <- process_attestations(new_store, attestations, states),
214230
{:ok, new_store} <- process_attester_slashings(new_store, attester_slashings) do
215231
{:ok, new_store}
216232
end
217233
end
218234

235+
def fetch_checkpoint_state(checkpoint) do
236+
case CheckpointStates.get_checkpoint_state(checkpoint) do
237+
{:ok, state} -> [{checkpoint, state}]
238+
_other -> []
239+
end
240+
end
241+
219242
defp apply_on_block(store, block_info) do
220243
Metrics.span_operation(:on_block, nil, nil, fn -> Handlers.on_block(store, block_info) end)
221244
end
@@ -226,29 +249,16 @@ defmodule LambdaEthereumConsensus.ForkChoice do
226249
end)
227250
end
228251

229-
defp process_attestations(store, attestations) do
252+
defp process_attestations(store, attestations, states) do
230253
Metrics.span_operation(:attestations, nil, nil, fn ->
231254
apply_handler(
232255
attestations,
233256
store,
234-
&Handlers.on_attestation(&1, &2, true, prefetch_states(attestations))
257+
&Handlers.on_attestation(&1, &2, true, states)
235258
)
236259
end)
237260
end
238261

239-
defp prefetch_states(attestations) do
240-
attestations
241-
|> Enum.map(& &1.data.target)
242-
|> Enum.uniq()
243-
|> Enum.flat_map(fn ch ->
244-
case CheckpointStates.get_checkpoint_state(ch) do
245-
{:ok, state} -> [{ch, state}]
246-
_other -> []
247-
end
248-
end)
249-
|> Map.new()
250-
end
251-
252262
@spec recompute_head(Store.t()) :: :ok
253263
def recompute_head(store) do
254264
{:ok, head_root} = Head.get_head(store)

lib/lambda_ethereum_consensus/state_transition/accessors.ex

Lines changed: 56 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ defmodule LambdaEthereumConsensus.StateTransition.Accessors do
33
Functions accessing the current `BeaconState`
44
"""
55

6+
require Logger
67
alias LambdaEthereumConsensus.StateTransition.Cache
78
alias LambdaEthereumConsensus.StateTransition.Math
89
alias LambdaEthereumConsensus.StateTransition.Misc
@@ -281,7 +282,12 @@ defmodule LambdaEthereumConsensus.StateTransition.Accessors do
281282
end
282283

283284
@doc """
284-
Return the number of committees in each slot for the given ``epoch``.
285+
Returns the number of committees in each slot for the given ``epoch``.
286+
287+
The amount of committees is (using integer division):
288+
active_validator_count / slots_per_epoch / TARGET_COMMITTEE_SIZE
289+
290+
The amount of committees will be capped between 1 and MAX_COMMITTEES_PER_SLOT.
285291
"""
286292
@spec get_committee_count_per_slot(BeaconState.t(), Types.epoch()) :: Types.uint64()
287293
def get_committee_count_per_slot(%BeaconState{} = state, epoch) do
@@ -300,7 +306,16 @@ defmodule LambdaEthereumConsensus.StateTransition.Accessors do
300306
end
301307

302308
@doc """
303-
Return the beacon committee at ``slot`` for ``index``.
309+
Returns the beacon committee at ``slot`` for ``index``.
310+
- slot is the one for which the committee is being calculated. Typically the slot of an
311+
attestation. Might be different from the state slot.
312+
- index: the index of the committee within the slot. It's not the committee index, which is the
313+
index of the committee within the epoch. This transformation is done internally.
314+
315+
The beacon committee returned is a list of global validator indices that should participate in
316+
the requested slot. The order in which the indices are sorted is the same as the one used in
317+
aggregation bits, so checking if the nth member of a committee participated is as simple as
318+
checking if the nth bit is set.
304319
"""
305320
@spec get_beacon_committee(BeaconState.t(), Types.slot(), Types.commitee_index()) ::
306321
{:ok, [Types.validator_index()]} | {:error, String.t()}
@@ -327,6 +342,41 @@ defmodule LambdaEthereumConsensus.StateTransition.Accessors do
327342
end
328343
end
329344

345+
@doc """
346+
Computes all committees for a single epoch and saves them in the cache. This only happens if the
347+
value is not calculated and if the root for the epoch is available. If any of those conditions
348+
is not true, this function is a noop.
349+
350+
Arguments:
351+
- state: state used to get active validators, seed and others. Any state that is within the same
352+
epoch is equivalent, as validators are updated in epoch boundaries.
353+
- epoch: epoch for which the committees are calculated.
354+
"""
355+
def maybe_prefetch_committees(state, epoch) do
356+
first_slot = Misc.compute_start_slot_at_epoch(epoch)
357+
358+
with {:ok, root} <- get_epoch_root(state, epoch),
359+
false <- Cache.present?(:beacon_committee, {first_slot, {0, root}}) do
360+
Logger.info("[Block processing] Computing committees for epoch #{epoch}")
361+
362+
committees_per_slot = get_committee_count_per_slot(state, epoch)
363+
364+
Misc.compute_all_committees(state, epoch)
365+
|> Enum.with_index()
366+
|> Enum.each(fn {committee, i} ->
367+
# The how do we know for which slot is a committee
368+
slot = first_slot + div(i, committees_per_slot)
369+
index = rem(i, committees_per_slot)
370+
371+
Cache.set(
372+
:beacon_committee,
373+
{slot, {index, root}},
374+
{:ok, committee |> Aja.Enum.to_list()}
375+
)
376+
end)
377+
end
378+
end
379+
330380
@spec get_base_reward_per_increment(BeaconState.t()) :: Types.gwei()
331381
def get_base_reward_per_increment(state) do
332382
numerator = ChainSpec.get("EFFECTIVE_BALANCE_INCREMENT") * ChainSpec.get("BASE_REWARD_FACTOR")
@@ -505,6 +555,10 @@ defmodule LambdaEthereumConsensus.StateTransition.Accessors do
505555

506556
@doc """
507557
Return the set of attesting indices corresponding to ``data`` and ``bits``.
558+
559+
It computes the committee for the attestation (indices of validators that should participate in
560+
that slot) and then filters the ones that actually participated. It returns an unordered MapSet,
561+
which is useful for checking inclusion, but should be ordered if used to validate an attestation.
508562
"""
509563
@spec get_attesting_indices(BeaconState.t(), Types.AttestationData.t(), Types.bitlist()) ::
510564
{:ok, MapSet.t()} | {:error, String.t()}

lib/lambda_ethereum_consensus/state_transition/cache.ex

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,4 +73,7 @@ defmodule LambdaEthereumConsensus.StateTransition.Cache do
7373
match_spec = generate_cleanup_spec(table, key)
7474
:ets.select_delete(table, match_spec)
7575
end
76+
77+
def present?(table, key), do: :ets.member(table, key)
78+
def set(table, key, value), do: :ets.insert_new(table, {key, value})
7679
end

lib/lambda_ethereum_consensus/state_transition/misc.ex

Lines changed: 72 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,11 @@ defmodule LambdaEthereumConsensus.StateTransition.Misc do
55

66
import Bitwise
77
require Aja
8+
require Logger
89

10+
alias LambdaEthereumConsensus.StateTransition.Accessors
11+
alias LambdaEthereumConsensus.StateTransition.Shuffling
12+
alias LambdaEthereumConsensus.Utils
913
alias Types.BeaconState
1014

1115
@max_random_byte 2 ** 8 - 1
@@ -180,9 +184,60 @@ defmodule LambdaEthereumConsensus.StateTransition.Misc do
180184
<<value::unsigned-integer-little-size(64)>>
181185
end
182186

187+
@doc """
188+
Gets all committees for a single epoch. More efficient than calculating each one, as the shuffling
189+
is done a single time for the whole index list and shared values are reused between committees.
190+
"""
191+
@spec compute_all_committees(BeaconState.t(), Types.epoch()) :: list(Aja.Vector.t())
192+
def compute_all_committees(state, epoch) do
193+
indices = Accessors.get_active_validator_indices(state, epoch)
194+
index_count = Aja.Vector.size(indices)
195+
seed = Accessors.get_seed(state, epoch, Constants.domain_beacon_attester())
196+
197+
shuffled_indices = Shuffling.shuffle_list(indices, seed) |> Aja.Vector.to_list()
198+
199+
committee_count =
200+
Accessors.get_committee_count_per_slot(state, epoch) * ChainSpec.get("SLOTS_PER_EPOCH")
201+
202+
committee_sizes =
203+
Enum.map(0..(committee_count - 1), fn index ->
204+
{c_start, c_end} = committee_boundaries(index, index_count, committee_count)
205+
c_end - c_start + 1
206+
end)
207+
208+
# separate using sizes.
209+
Utils.chunk_by_sizes(shuffled_indices, committee_sizes)
210+
end
211+
183212
@doc """
184213
Computes the validator indices of the ``committee_index``-th committee at some epoch
185214
with ``committee_count`` committees, and for some given ``indices`` and ``seed``.
215+
216+
Args:
217+
- indices: a full list of all active validator indices for a single epoch.
218+
- seed: for shuffling calculations.
219+
- committee_index: global number representing the order of the requested committee within the
220+
whole epoch.
221+
- committee_count: total amount of committees for the epoch. Useful to determine the start and end
222+
of the requested committee.
223+
224+
Returns:
225+
- The list of indices for the validators that conform the requested committee. The order is the
226+
same as used in the aggregation bits of an attestation in that committee.
227+
228+
PERFORMANCE NOTE:
229+
230+
Instead of shuffling the full index list, it focuses on the positions of the requested committee
231+
and calculates their shuffled index. Because of the symmetric nature of the shuffling algorithm,
232+
looking at the shuffled index position in the index list gives the element that would end up in
233+
the committee if the full list was to be shuffled.
234+
235+
This is, in logic, equivalent to shuffling the whole validator index list and getting the
236+
elements for the committee under calculation, but only calculating the shuffling for the elements
237+
of the committee.
238+
239+
While the amount of calculations is smaller than the full shuffling, calling this for every
240+
committee in an epoch is inefficient. For that end, compute_all_committees should be called.
186241
"""
187242
@spec compute_committee(Aja.Vector.t(), Types.bytes32(), Types.uint64(), Types.uint64()) ::
188243
{:error, String.t()}
@@ -197,8 +252,9 @@ defmodule LambdaEthereumConsensus.StateTransition.Misc do
197252
def compute_committee(indices, seed, committee_index, committee_count)
198253
when committee_index < committee_count do
199254
index_count = Aja.Vector.size(indices)
200-
committee_start = div(index_count * committee_index, committee_count)
201-
committee_end = div(index_count * (committee_index + 1), committee_count) - 1
255+
256+
{committee_start, committee_end} =
257+
committee_boundaries(committee_index, index_count, committee_count)
202258

203259
committee_start..committee_end//1
204260
# NOTE: this cannot fail because committee_end < index_count
@@ -211,6 +267,20 @@ defmodule LambdaEthereumConsensus.StateTransition.Misc do
211267

212268
def compute_committee(_, _, _, _), do: {:error, "Invalid committee index"}
213269

270+
@doc """
271+
Computes the boundaries of a committee.
272+
273+
Args:
274+
- committee_index: epoch based committee index.
275+
- index_count: amount of active validators participating in the epoch.
276+
- committee_count: amount of committees that will be formed in the epoch.
277+
"""
278+
def committee_boundaries(committee_index, index_count, committee_count) do
279+
committee_start = div(index_count * committee_index, committee_count)
280+
committee_end = div(index_count * (committee_index + 1), committee_count) - 1
281+
{committee_start, committee_end}
282+
end
283+
214284
@doc """
215285
Return the 32-byte fork data root for the ``current_version`` and ``genesis_validators_root``.
216286
This is used primarily in signature domains to avoid collisions across forks/chains.

lib/lambda_ethereum_consensus/state_transition/operations.ex

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -847,10 +847,14 @@ defmodule LambdaEthereumConsensus.StateTransition.Operations do
847847
end
848848

849849
defp check_matching_aggregation_bits_length(attestation, beacon_committee) do
850-
if BitList.length(attestation.aggregation_bits) == length(beacon_committee) do
850+
aggregation_bits_length = BitList.length(attestation.aggregation_bits)
851+
beacon_committee_length = length(beacon_committee)
852+
853+
if aggregation_bits_length == beacon_committee_length do
851854
:ok
852855
else
853-
{:error, "Mismatched aggregation bits length"}
856+
{:error,
857+
"Mismatched length. aggregation_bits: #{aggregation_bits_length}. beacon_committee: #{beacon_committee_length}"}
854858
end
855859
end
856860

0 commit comments

Comments
 (0)