Skip to content

Commit 67a7aaf

Browse files
committed
Merge branch 'main' of github.com:lambdaclass/lambda_ethereum_consensus into move-runners-to-test
2 parents 1588ee9 + 84a83a7 commit 67a7aaf

File tree

5 files changed

+275
-30
lines changed

5 files changed

+275
-30
lines changed

bench/ssz.exs

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,3 +34,26 @@ Benchee.run(
3434
warmup: 2,
3535
time: 5
3636
)
37+
38+
## Benchmark Merkleization
39+
40+
list = Stream.cycle([65_535]) |> Enum.take(316)
41+
schema = {:list, {:int, 16}, 1024}
42+
packed_chunks = SszEx.pack(list, schema)
43+
limit = SszEx.chunk_count(schema)
44+
45+
Benchee.run(
46+
%{
47+
"SszEx.merkleize_chunks" => fn {chunks, leaf_count} ->
48+
SszEx.merkleize_chunks(chunks, leaf_count)
49+
end,
50+
"SszEx.merkleize_chunks_with_virtual_padding" => fn {chunks, leaf_count} ->
51+
SszEx.merkleize_chunks_with_virtual_padding(chunks, leaf_count)
52+
end
53+
},
54+
inputs: %{
55+
"args" => {packed_chunks, limit}
56+
},
57+
warmup: 2,
58+
time: 5
59+
)

lib/ssz_ex.ex

Lines changed: 148 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ defmodule LambdaEthereumConsensus.SszEx do
55
alias LambdaEthereumConsensus.Utils.BitList
66
alias LambdaEthereumConsensus.Utils.BitVector
77
import alias LambdaEthereumConsensus.Utils.BitVector
8+
alias LambdaEthereumConsensus.Utils.ZeroHashes
89

910
#################
1011
### Public API
@@ -15,6 +16,7 @@ defmodule LambdaEthereumConsensus.SszEx do
1516
@bits_per_byte 8
1617
@bits_per_chunk @bytes_per_chunk * @bits_per_byte
1718
@zero_chunk <<0::size(@bits_per_chunk)>>
19+
@zero_hashes ZeroHashes.compute_zero_hashes()
1820

1921
@spec hash(iodata()) :: binary()
2022
def hash(data), do: :crypto.hash(:sha256, data)
@@ -73,6 +75,41 @@ defmodule LambdaEthereumConsensus.SszEx do
7375
@spec hash_tree_root!(non_neg_integer, {:int, non_neg_integer}) :: Types.root()
7476
def hash_tree_root!(value, {:int, size}), do: pack(value, {:int, size})
7577

78+
@spec hash_tree_root!(binary, {:bytes, non_neg_integer}) :: Types.root()
79+
def hash_tree_root!(value, {:bytes, size}) do
80+
packed_chunks = pack(value, {:bytes, size})
81+
leaf_count = packed_chunks |> get_chunks_len() |> next_pow_of_two()
82+
root = merkleize_chunks_with_virtual_padding(packed_chunks, leaf_count)
83+
root
84+
end
85+
86+
@spec hash_tree_root!(list(), {:list, any, non_neg_integer}) :: Types.root()
87+
def hash_tree_root!(list, {:list, type, size}) do
88+
{:ok, root} = hash_tree_root(list, {:list, type, size})
89+
root
90+
end
91+
92+
@spec hash_tree_root!(list(), {:vector, any, non_neg_integer}) :: Types.root()
93+
def hash_tree_root!(vector, {:vector, type, size}) do
94+
{:ok, root} = hash_tree_root(vector, {:vector, type, size})
95+
root
96+
end
97+
98+
@spec hash_tree_root!(struct(), atom()) :: Types.root()
99+
def hash_tree_root!(container, module) when is_map(container) do
100+
chunks =
101+
module.schema()
102+
|> Enum.reduce(<<>>, fn {key, schema}, acc_root ->
103+
value = container |> Map.get(key)
104+
root = hash_tree_root!(value, schema)
105+
acc_root <> root
106+
end)
107+
108+
leaf_count = chunks |> get_chunks_len() |> next_pow_of_two()
109+
root = merkleize_chunks_with_virtual_padding(chunks, leaf_count)
110+
root
111+
end
112+
76113
@spec hash_tree_root(list(), {:list, any, non_neg_integer}) ::
77114
{:ok, Types.root()} | {:error, String.t()}
78115
def hash_tree_root(list, {:list, type, size}) do
@@ -109,7 +146,7 @@ defmodule LambdaEthereumConsensus.SszEx do
109146
if chunks_len > limit do
110147
{:error, "chunk size exceeds limit"}
111148
else
112-
root = merkleize_chunks(chunks, limit) |> mix_in_length(len)
149+
root = merkleize_chunks_with_virtual_padding(chunks, limit) |> mix_in_length(len)
113150
{:ok, root}
114151
end
115152
end
@@ -118,7 +155,7 @@ defmodule LambdaEthereumConsensus.SszEx do
118155
{:ok, Types.root()} | {:error, String.t()}
119156
def hash_tree_root_vector_basic_type(chunks) do
120157
leaf_count = chunks |> get_chunks_len() |> next_pow_of_two()
121-
root = merkleize_chunks(chunks, leaf_count)
158+
root = merkleize_chunks_with_virtual_padding(chunks, leaf_count)
122159
{:ok, root}
123160
end
124161

@@ -164,6 +201,38 @@ defmodule LambdaEthereumConsensus.SszEx do
164201
end
165202
end
166203

204+
def merkleize_chunks_with_virtual_padding(chunks, leaf_count) do
205+
chunks_len = chunks |> get_chunks_len()
206+
power = leaf_count |> compute_pow()
207+
height = power + 1
208+
209+
cond do
210+
chunks_len == 0 ->
211+
depth = height - 1
212+
get_zero_hash(depth)
213+
214+
chunks_len == 1 and leaf_count == 1 ->
215+
chunks
216+
217+
true ->
218+
power = leaf_count |> compute_pow()
219+
height = power + 1
220+
layers = chunks
221+
last_index = chunks_len - 1
222+
223+
{_, final_layer} =
224+
1..(height - 1)
225+
|> Enum.reverse()
226+
|> Enum.reduce({last_index, layers}, fn i, {acc_last_index, acc_layers} ->
227+
updated_layers = update_layers(i, height, acc_layers, acc_last_index)
228+
{acc_last_index |> div(2), updated_layers}
229+
end)
230+
231+
<<root::binary-size(@bytes_per_chunk), _::binary>> = final_layer
232+
root
233+
end
234+
end
235+
167236
@spec pack(boolean, :bool) :: binary()
168237
def pack(true, :bool), do: <<1::@bits_per_chunk-little>>
169238
def pack(false, :bool), do: @zero_chunk
@@ -173,6 +242,11 @@ defmodule LambdaEthereumConsensus.SszEx do
173242
<<value::size(size)-little>> |> pack_bytes()
174243
end
175244

245+
@spec pack(binary, {:bytes, non_neg_integer}) :: binary()
246+
def pack(value, {:bytes, _size}) do
247+
value |> pack_bytes()
248+
end
249+
176250
@spec pack(list(), {:list | :vector, any, non_neg_integer}) :: binary() | :error
177251
def pack(list, {type, schema, _}) when type in [:vector, :list] do
178252
if variable_size?(schema) do
@@ -184,6 +258,11 @@ defmodule LambdaEthereumConsensus.SszEx do
184258
end
185259
end
186260

261+
def chunk_count({:list, type, max_size}) do
262+
size = size_of(type)
263+
(max_size * size + 31) |> div(32)
264+
end
265+
187266
#################
188267
### Private functions
189268
#################
@@ -592,16 +671,6 @@ defmodule LambdaEthereumConsensus.SszEx do
592671

593672
defp size_of({:int, size}), do: size |> div(@bits_per_byte)
594673

595-
defp chunk_count({:list, {:int, size}, max_size}) do
596-
size = size_of({:int, size})
597-
(max_size * size + 31) |> div(32)
598-
end
599-
600-
defp chunk_count({:list, :bool, max_size}) do
601-
size = size_of(:bool)
602-
(max_size * size + 31) |> div(32)
603-
end
604-
605674
defp pack_basic_type_list(list, schema) do
606675
list
607676
|> Enum.reduce(<<>>, fn x, acc ->
@@ -635,12 +704,69 @@ defmodule LambdaEthereumConsensus.SszEx do
635704
end
636705
end
637706

638-
defp next_pow_of_two(len) when is_integer(len) and len >= 0 do
639-
if len == 0 do
640-
0
641-
else
642-
n = ((len <<< 1) - 1) |> :math.log2() |> trunc()
643-
2 ** n
707+
defp next_pow_of_two(0), do: 0
708+
709+
defp next_pow_of_two(len) when is_integer(len) and len > 0 do
710+
n = ((len <<< 1) - 1) |> compute_pow()
711+
2 ** n
712+
end
713+
714+
defp get_chunks_len(chunks) do
715+
chunks |> byte_size() |> div(@bytes_per_chunk)
716+
end
717+
718+
defp compute_pow(value) do
719+
:math.log2(value) |> trunc()
720+
end
721+
722+
defp update_layers(i, height, acc_layers, acc_last_index) do
723+
0..(2 ** i - 1)
724+
|> Enum.filter(fn x -> rem(x, 2) == 0 end)
725+
|> Enum.reduce_while(acc_layers, fn j, acc_layers ->
726+
parent_index = j |> div(2)
727+
nodes = get_nodes(parent_index, i, j, height, acc_layers, acc_last_index)
728+
hash_nodes_and_replace(nodes, acc_layers)
729+
end)
730+
end
731+
732+
defp get_nodes(parent_index, _i, j, _height, acc_layers, acc_last_index)
733+
when j < acc_last_index do
734+
start = parent_index * @bytes_per_chunk
735+
stop = (j + 2) * @bytes_per_chunk
736+
focus = acc_layers |> :binary.part(start, stop - start)
737+
focus_len = focus |> byte_size()
738+
children_index = focus_len - 2 * @bytes_per_chunk
739+
<<_::binary-size(children_index), children::binary>> = focus
740+
741+
<<left::binary-size(@bytes_per_chunk), right::binary-size(@bytes_per_chunk)>> =
742+
children
743+
744+
{children_index, left, right}
745+
end
746+
747+
defp get_nodes(parent_index, i, j, height, acc_layers, acc_last_index)
748+
when j == acc_last_index do
749+
start = parent_index * @bytes_per_chunk
750+
stop = (j + 1) * @bytes_per_chunk
751+
focus = acc_layers |> :binary.part(start, stop - start)
752+
focus_len = focus |> byte_size()
753+
children_index = focus_len - @bytes_per_chunk
754+
<<_::binary-size(children_index), left::binary>> = focus
755+
depth = height - i - 1
756+
right = get_zero_hash(depth)
757+
{children_index, left, right}
758+
end
759+
760+
defp get_nodes(_, _, _, _, _, _), do: :error
761+
762+
defp hash_nodes_and_replace(nodes, layers) do
763+
case nodes do
764+
:error ->
765+
{:halt, layers}
766+
767+
{index, left, right} ->
768+
hash = hash_nodes(left, right)
769+
{:cont, replace_chunk(layers, index, hash)}
644770
end
645771
end
646772

@@ -651,7 +777,9 @@ defmodule LambdaEthereumConsensus.SszEx do
651777
<<left::binary, new_chunk::binary, right::binary>>
652778
end
653779

654-
defp get_chunks_len(chunks) do
655-
chunks |> byte_size() |> div(@bytes_per_chunk)
780+
defp get_zero_hash(depth) do
781+
offset = (depth + 1) * @bytes_per_chunk - @bytes_per_chunk
782+
<<_::binary-size(offset), hash::binary-size(@bytes_per_chunk), _::binary>> = @zero_hashes
783+
hash
656784
end
657785
end

lib/utils/zero_hashes.ex

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
defmodule LambdaEthereumConsensus.Utils.ZeroHashes do
2+
@moduledoc """
3+
Precomputed zero hashes
4+
"""
5+
6+
@bits_per_byte 8
7+
@bytes_per_chunk 32
8+
@bits_per_chunk @bytes_per_chunk * @bits_per_byte
9+
@max_merkle_tree_depth 64
10+
11+
def compute_zero_hashes do
12+
buffer = <<0::size(@bytes_per_chunk * @max_merkle_tree_depth * @bits_per_byte)>>
13+
14+
0..(@max_merkle_tree_depth - 2)
15+
|> Enum.reduce(buffer, fn index, acc_buffer ->
16+
start = index * @bytes_per_chunk
17+
stop = (index + 2) * @bytes_per_chunk
18+
focus = acc_buffer |> :binary.part(start, stop - start)
19+
<<left::binary-size(@bytes_per_chunk), _::binary>> = focus
20+
hash = hash_nodes(left, left)
21+
change_index = (index + 1) * @bytes_per_chunk
22+
replace_chunk(acc_buffer, change_index, hash)
23+
end)
24+
end
25+
26+
defp hash_nodes(left, right), do: :crypto.hash(:sha256, left <> right)
27+
28+
defp replace_chunk(chunks, start, new_chunk) do
29+
<<left::binary-size(start), _::size(@bits_per_chunk), right::binary>> =
30+
chunks
31+
32+
<<left::binary, new_chunk::binary, right::binary>>
33+
end
34+
end

test/spec/runners/ssz_generic.ex

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,13 +69,15 @@ defmodule SszGenericTestRunner do
6969
{:container, module},
7070
real_serialized,
7171
real_deserialized,
72-
_hash_tree_root
72+
expected_hash_tree_root
7373
) do
7474
real_struct = struct!(module, real_deserialized)
7575
{:ok, deserialized} = SszEx.decode(real_serialized, module)
7676
assert deserialized == real_struct
7777
{:ok, serialized} = SszEx.encode(real_struct, module)
7878
assert serialized == real_serialized
79+
actual_hash_tree_root = SszEx.hash_tree_root!(real_struct, module)
80+
assert actual_hash_tree_root == expected_hash_tree_root
7981
end
8082

8183
defp assert_ssz(
@@ -92,7 +94,7 @@ defmodule SszGenericTestRunner do
9294

9395
assert serialized == real_serialized
9496

95-
{:ok, actual_hash_tree_root} = SszEx.hash_tree_root(real_deserialized, schema)
97+
actual_hash_tree_root = SszEx.hash_tree_root!(real_deserialized, schema)
9698

9799
assert actual_hash_tree_root == expected_hash_tree_root
98100
end

0 commit comments

Comments
 (0)