From 71da6a6aa40b227c7f53e6ea769bdff7cb4f343e Mon Sep 17 00:00:00 2001 From: Anders Dalskov Date: Sun, 30 Jul 2023 20:31:36 +0200 Subject: [PATCH] new version --- CMakeLists.txt | 7 +- RELEASE.txt | 46 +- include/scl/math/curves/secp256k1.h | 2 +- include/scl/math/ec.h | 21 +- include/scl/math/ec_ops.h | 2 +- include/scl/math/ff.h | 30 + include/scl/math/ff_ops.h | 2 + include/scl/math/lagrange.h | 16 +- include/scl/math/mat.h | 10 +- include/scl/math/ops_gmp_ff.h | 286 ++++---- include/scl/math/poly.h | 8 + include/scl/math/ring.h | 110 --- include/scl/math/vec.h | 10 +- include/scl/net/channel.h | 3 - include/scl/protocol/base.h | 4 +- include/scl/simulation/buffer.h | 11 +- include/scl/simulation/channel.h | 15 +- include/scl/simulation/channel_id.h | 12 +- include/scl/simulation/config.h | 132 ++-- include/scl/simulation/context.h | 149 +++- include/scl/simulation/env.h | 13 +- include/scl/simulation/event.h | 26 +- include/scl/simulation/manager.h | 149 ++++ include/scl/simulation/measurement.h | 117 --- include/scl/simulation/mem_channel_buffer.h | 35 +- include/scl/simulation/result.h | 30 +- include/scl/simulation/simulator.h | 80 +-- include/scl/ss/feldman.h | 9 +- include/scl/ss/shamir.h | 122 +++- include/scl/util/digest.h | 8 +- include/scl/util/hash.h | 1 + include/scl/util/iuf_hash.h | 17 + include/scl/util/merkle.h | 194 +++++ include/scl/util/sha256.h | 2 +- include/scl/util/sha3.h | 2 +- include/scl/util/sign.h | 57 +- include/scl/util/traits.h | 32 +- src/scl/math/secp256k1_curve.cc | 204 +++--- src/scl/math/secp256k1_field.cc | 58 +- src/scl/math/secp256k1_helpers.h | 8 +- ...secp256k1_order.cc => secp256k1_scalar.cc} | 56 +- src/scl/simulation/channel.cc | 151 ++-- src/scl/simulation/config.cc | 38 +- src/scl/simulation/context.cc | 34 +- src/scl/simulation/event.cc | 75 +- src/scl/simulation/measurement.cc | 75 +- src/scl/simulation/result.cc | 176 ++--- src/scl/simulation/simulate_recv_time.cc | 27 +- src/scl/simulation/simulator.cc | 188 ++--- test/scl/math/fields.h | 2 +- test/scl/math/test_ff.cc | 16 + test/scl/math/test_poly.cc | 1 + test/scl/math/test_secp256k1.cc | 21 +- test/scl/math/test_vec.cc | 25 + test/scl/simulation/test_channel.cc | 200 ++++++ test/scl/simulation/test_config.cc | 100 ++- test/scl/simulation/test_context.cc | 68 +- test/scl/simulation/test_env.cc | 26 +- test/scl/simulation/test_event.cc | 5 + test/scl/simulation/test_manager.cc | 58 ++ test/scl/simulation/test_measurement.cc | 25 +- .../scl/simulation/test_mem_channel_buffer.cc | 32 +- test/scl/simulation/test_result.cc | 46 +- test/scl/simulation/test_simulator.cc | 676 +++++------------- test/scl/ss/test_feldman.cc | 2 +- test/scl/ss/test_shamir.cc | 46 ++ test/scl/util/test_merkle.cc | 88 +++ test/scl/util/test_sha3.cc | 18 +- 68 files changed, 2595 insertions(+), 1720 deletions(-) delete mode 100644 include/scl/math/ring.h create mode 100644 include/scl/simulation/manager.h create mode 100644 include/scl/util/merkle.h rename src/scl/math/{secp256k1_order.cc => secp256k1_scalar.cc} (75%) create mode 100644 test/scl/simulation/test_channel.cc create mode 100644 test/scl/simulation/test_manager.cc create mode 100644 test/scl/util/test_merkle.cc diff --git a/CMakeLists.txt b/CMakeLists.txt index 1e4e9c8..b8866c6 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -16,7 +16,7 @@ cmake_minimum_required( VERSION 3.14 ) -project( scl VERSION 6.2.0 DESCRIPTION "Secure Computation Library" ) +project( scl VERSION 0.7.0 DESCRIPTION "Secure Computation Library" ) if(NOT CMAKE_BUILD_TYPE) set(CMAKE_BUILD_TYPE Release) @@ -64,7 +64,7 @@ if(WITH_EC MATCHES ON) src/scl/math/ops_gmp_ff.cc src/scl/math/secp256k1_field.cc src/scl/math/secp256k1_curve.cc - src/scl/math/secp256k1_order.cc + src/scl/math/secp256k1_scalar.cc src/scl/math/number.cc) endif() @@ -102,6 +102,7 @@ if(CMAKE_BUILD_TYPE MATCHES "Debug") test/scl/util/test_sha256.cc test/scl/util/test_ecdsa.cc test/scl/util/test_cmdline.cc + test/scl/util/test_merkle.cc test/scl/gf7.cc test/scl/math/test_mersenne61.cc @@ -137,7 +138,9 @@ if(CMAKE_BUILD_TYPE MATCHES "Debug") test/scl/simulation/test_result.cc test/scl/simulation/test_measurement.cc test/scl/simulation/test_mem_channel_buffer.cc + test/scl/simulation/test_channel.cc test/scl/simulation/test_env.cc + test/scl/simulation/test_manager.cc test/scl/serialization/test_serializer.cc) diff --git a/RELEASE.txt b/RELEASE.txt index cb2d2b1..c3928d3 100644 --- a/RELEASE.txt +++ b/RELEASE.txt @@ -1,14 +1,30 @@ -6.2.0: More functionality for Number +0.7.0: +- Exponentiation for field elements +- Various bug fixes. Especially in the simulation code +- Change versioning. Make all releases start with 0 (to mark them as pre-release). +- Merkle tree hashing. +- Make it possible to hash anything which has a Serializer specialization. +- Vec::ScalarMultiply now allows multiplying a Vec of curve points with a + scalar. Same for Mat. +- Make it possible to prematurely terminate a party in a simulation. +- Introduce a "Manager" class that contains the parameters of a simulation. +- Rename EC::Order to EC::ScalarField. +- Introduce a function for acquiring the order of a field. +- Make utility functions in ECDSA public. +- Various optimizations for the elliptic curve code. +- Simplify the measurement class. + +0.6.2: More functionality for Number - Add modulo operator to Number. - Add some mathematical functions that operate on numbers. - Make Number serializable; add Serializer specialization. - Add a simple command-line argument parser. -6.1.0: Extend serialization functionality +0.6.1: Extend serialization functionality - Make Write methods return the number of bytes written. - Make it possible to serialize vectors with arbitrary content. -6.0.0: Improvements to serialization and Channels. +0.6.0: Improvements to serialization and Channels. - Added a Serializer type that can be specialized in order to specify how various objects are converted to bytes. - Added a Packet type that allows reading and writing almost arbitrary objects, @@ -17,37 +33,33 @@ Packets. Remove old Send/Recv overloads. - Remove proto::ProtocolEnvironment. -5.3.0: ECDSA +0.5.3: ECDSA - Added functionality for creating ECDSA signatures. -5.2.0: Protocol environment extensions +0.5.2: Protocol environment extensions - Make it possible to create "checkpoints" through the protocol environment clock. - fix a bug that prevented the documentation from being buildt - Rename ProtocolEnvironment to Env, and introduce a typedef for backwards compatability. -5.1.2: Style changes +0.5.1: Style changes - Change naming style of private field members. - -5.1.1: Bug fixes and simplifications - Simplifed the NextToRun logic because a greedy strategy too often results in rollbacks. - Fixed a bug in the Rollback logic where WriteOps weren't rolled back correctly. - -5.1: Vec-Mat multiplication - Add a Vec Mat to Vec multiplication function to Mat - Minor refactoring of test_mat.cc -5.0: Simulation +0.5.0: Simulation - Added a new module for simulating protocol executions under different network conditions. - Refactored layout with respect to namespaces. details no longer exists, and the different modules have gotten their own namespace. - Up test coverage to 100%. Minor refactoring to the actions. -4.0: Shamir, Feldman, SHA-256 +0.4.0: Shamir, Feldman, SHA-256 - Refactor Shamir to allow caching of Lagrange coefficients - Add support for Feldman Secret Sharing - Add support for SHA-256 @@ -58,7 +70,7 @@ - Fix negation of 0 in Secp256k1::Field and Secp256k1::Order - Make serialization and deserialization of curve points behave more sanely -3.0: More features, build changes +0.3.0: More features, build changes - Add method for returning a point as a pair of affine coordinates - Add method to check if a channel has data available - Allow sending and receiving STL vectors without specifying the size @@ -72,12 +84,12 @@ - disable actions for master branch - add clang-tidy action -2.1: More Finite Fields +0.2.1: More Finite Fields - Provide a FF implementation for computations modulo the order of Secp256k1 - Extend EC with support for scalar multiplications with scalars from a finite field of size the order of a subgroup. -2.0: Elliptic curves and finite field refactoring +0.2.0: Elliptic curves and finite field refactoring - Make it simpler to define new finite fields - Include optional (but enabled by default) support for elliptic curves - Implement secp256k1 @@ -87,13 +99,13 @@ - Rename FF to Fp. - Move class FF into scl namespace. -1.1: Refactoring of finite field internals +0.1.1: Refactoring of finite field internals - Finite field operations are now defined by individual specializations of templated functions - Remove DEFINE_FINITE_FIELD macro - Move Mersenne61 and Mersenne127 definitions into ff.h -1.0: Initial public version of SCL. +0.1.0: Initial public version of SCL. - Features: - Math: - Finite Field class with two instantiations based on Mersenne primes diff --git a/include/scl/math/curves/secp256k1.h b/include/scl/math/curves/secp256k1.h index 0ef9c23..085c944 100644 --- a/include/scl/math/curves/secp256k1.h +++ b/include/scl/math/curves/secp256k1.h @@ -57,7 +57,7 @@ struct Secp256k1 { /** * @brief Finite field modulo a Secp256k1 prime order sub-group. */ - struct Order { + struct Scalar { /** * @brief Internal type of elements. */ diff --git a/include/scl/math/ec.h b/include/scl/math/ec.h index b45be5c..9306872 100644 --- a/include/scl/math/ec.h +++ b/include/scl/math/ec.h @@ -47,7 +47,7 @@ class EC final : Add>, Eq>, Print> { /** * @brief A large sub-group of this curve. */ - using Order = FF; + using ScalarField = FF; /** * @brief The size of a curve point in bytes. @@ -165,7 +165,7 @@ class EC final : Add>, Eq>, Print> { * @param scalar the scalar * @return this. */ - EC& operator*=(const Order& scalar) { + EC& operator*=(const ScalarField& scalar) { CurveScalarMultiply(m_value, scalar); return *this; } @@ -187,7 +187,7 @@ class EC final : Add>, Eq>, Print> { * @param scalar the scalar * @return the point multiplied with the scalar. */ - friend EC operator*(const EC& point, const Order& scalar) { + friend EC operator*(const EC& point, const ScalarField& scalar) { EC copy(point); return copy *= scalar; } @@ -208,10 +208,9 @@ class EC final : Add>, Eq>, Print> { * @param scalar the scalar * @return the point multiplied with the scalar. */ - friend EC operator*(const FF& scalar, - const EC& point) { + friend EC operator*(const ScalarField& scalar, const EC& point) { return point * scalar; - } + } // LCOV_EXCL_LINE /** * @brief Negate this point. @@ -229,7 +228,7 @@ class EC final : Add>, Eq>, Print> { */ bool Equal(const EC& other) const { return CurveEqual(m_value, other.m_value); - } + } // LCOV_EXCL_LINE /** * @brief Check if this point is equal to the point at inifity. @@ -237,7 +236,7 @@ class EC final : Add>, Eq>, Print> { */ bool PointAtInfinity() const { return CurveIsPointAtInfinity(m_value); - } + } // LCOV_EXCL_LINE /** * @brief Return this point as a pair of affine coordinates. @@ -245,14 +244,14 @@ class EC final : Add>, Eq>, Print> { */ std::array ToAffine() const { return CurveToAffine(m_value); - } + } // LCOV_EXCL_LINE /** * @brief Output this point as a string. */ std::string ToString() const { return CurveToString(m_value); - } + } // LCOV_EXCL_LINE /** * @brief Write this point to a buffer. @@ -261,7 +260,7 @@ class EC final : Add>, Eq>, Print> { */ void Write(unsigned char* dest, bool compress = true) const { CurveToBytes(dest, m_value, compress); - } + } // LCOV_EXCL_LINE private: typename Curve::ValueType m_value; diff --git a/include/scl/math/ec_ops.h b/include/scl/math/ec_ops.h index 5e9ab3a..6adabb0 100644 --- a/include/scl/math/ec_ops.h +++ b/include/scl/math/ec_ops.h @@ -110,7 +110,7 @@ void CurveScalarMultiply(typename C::ValueType& out, const Number& scalar); */ template void CurveScalarMultiply(typename C::ValueType& out, - const FF& scalar); + const FF& scalar); /** * @brief Check if two elliptic curve points are equal. diff --git a/include/scl/math/ff.h b/include/scl/math/ff.h index 560d1da..c7ed9f6 100644 --- a/include/scl/math/ff.h +++ b/include/scl/math/ff.h @@ -255,6 +255,36 @@ class FF final : Add>, friend class FFAccess; }; +/** + * @brief Returns the order of a finite field. + */ +template +Number Order(); + +/** + * @brief Raise an element to a power. + * @param base the base. + * @param exp the exponent. + * @return \p base raised to the \p exp th power. + */ +template +FF Exp(const FF& base, std::size_t exp) { + if (exp == 0) { + return FF::One(); + } + + const auto n = sizeof(std::size_t) * 8 - __builtin_clzll(exp); + FF r = FF::One(); + for (std::size_t i = n; i-- > 0;) { + r *= r; + if (((exp >> i) & 1) == 1) { + r *= base; + } + } + + return r; +} + } // namespace scl::math #endif // SCL_MATH_FF_H diff --git a/include/scl/math/ff_ops.h b/include/scl/math/ff_ops.h index e4c4f5e..8b81ee0 100644 --- a/include/scl/math/ff_ops.h +++ b/include/scl/math/ff_ops.h @@ -23,6 +23,8 @@ #include #include +#include "scl/math/number.h" + namespace scl::math { /** diff --git a/include/scl/math/lagrange.h b/include/scl/math/lagrange.h index a908b68..ec8c394 100644 --- a/include/scl/math/lagrange.h +++ b/include/scl/math/lagrange.h @@ -52,8 +52,7 @@ namespace scl::math { * @see https://en.wikipedia.org/wiki/Lagrange_polynomial */ template -Vec ComputeLagrangeBasis(const math::Vec& nodes, int x) { - const auto _x = T{x}; +Vec ComputeLagrangeBasis(const math::Vec& nodes, const T& x) { const auto n = nodes.Size(); std::vector b; b.reserve(n); @@ -63,7 +62,7 @@ Vec ComputeLagrangeBasis(const math::Vec& nodes, int x) { for (std::size_t j = 0; j < n; ++j) { if (i != j) { const auto xj = nodes[j]; - ell *= (_x - xj) / (xi - xj); + ell *= (x - xj) / (xi - xj); } } b.emplace_back(ell); @@ -71,6 +70,17 @@ Vec ComputeLagrangeBasis(const math::Vec& nodes, int x) { return b; } +/** + * @brief Computes a lagrange basis for a set of nodes. + * @param nodes the set of nodes. + * @param x the evaluation point x. + * @see ComputeLagrangeBasis + */ +template +Vec ComputeLagrangeBasis(const math::Vec& nodes, int x) { + return ComputeLagrangeBasis(nodes, T{x}); +} + } // namespace scl::math #endif // SCL_MATH_LAGRANGE_H diff --git a/include/scl/math/mat.h b/include/scl/math/mat.h index a157a5b..e46f4dc 100644 --- a/include/scl/math/mat.h +++ b/include/scl/math/mat.h @@ -332,7 +332,10 @@ class Mat : Print> { * @param scalar the scalar * @return this scaled by \p scalar. */ - Mat ScalarMultiply(const Elem& scalar) const { + template < + typename Scalar, + std::enable_if_t::value, bool> = true> + Mat ScalarMultiply(const Scalar& scalar) const { Mat copy(m_rows, m_cols, m_values); return copy.ScalarMultiplyInPlace(scalar); } @@ -342,7 +345,10 @@ class Mat : Print> { * @param scalar the scalar * @return this scaled by \p scalar. */ - Mat& ScalarMultiplyInPlace(const Elem& scalar) { + template < + typename Scalar, + std::enable_if_t::value, bool> = true> + Mat& ScalarMultiplyInPlace(const Scalar& scalar) { for (auto& v : m_values) { v *= scalar; } diff --git a/include/scl/math/ops_gmp_ff.h b/include/scl/math/ops_gmp_ff.h index 3b6bf43..11aebde 100644 --- a/include/scl/math/ops_gmp_ff.h +++ b/include/scl/math/ops_gmp_ff.h @@ -18,6 +18,7 @@ #ifndef SCL_MATH_OPS_GMP_FF_H #define SCL_MATH_OPS_GMP_FF_H +#include #include #include #include @@ -33,147 +34,175 @@ namespace scl::math { #define SCL_BITS_PER_LIMB static_cast(mp_bits_per_limb) #define SCL_BYTES_PER_LIMB sizeof(mp_limb_t) -#define SCL_COPY(out, in, size) \ - do { \ - for (std::size_t i = 0; i < (size); ++i) { \ - *((out) + i) = *((in) + i); \ - } \ - } while (0) +/** + * @brief Reduction parameters used to perform Montgomery reduction. + * @tparam N the number of words in the parameters. + * + * This struct is used to perform Montgomery modular reductions and is used + * throughout all Monty* functions. + */ +template +struct RedParams { + /** + * @brief The prime. + */ + mp_limb_t prime[N]; + + /** + * @brief A constant used in montgomery reduction. + * + * This constant is computed as \f$mc = -prime^{-1} \mod 2^{w * N}\f$ where + * \f$w\f$ is the word size in bits (probably 64). + */ + mp_limb_t mc[N]; +}; /** * @brief Convert a value into montgomery form mod some prime. - * @tparam N the size of the input - * @param value the value to convert - * @param mod the prime + * @tparam N the number of limbs in the value to convert. + * @param out the value to convert. + * @param rp reduction parameters. */ template -void MontyIn(mp_limb_t* value, const mp_limb_t* mod) { +void MontyIn(mp_limb_t* out, const RedParams rp) { mp_limb_t qp[N + 1]; mp_limb_t shift[2 * N] = {0}; - // multiply val by 2^{256} - SCL_COPY(shift + N, value, N); - // compute (val * 2^{256}) mod p - mpn_tdiv_qr(qp, value, 0, shift, 2 * N, mod, N); + // multiply val by 2^{w * N} + std::copy(out, out + N, shift + N); + // compute (val * 2^{w * N}) mod p + mpn_tdiv_qr(qp, out, 0, shift, 2 * N, rp.prime, N); } /** * @brief Perform a montgomery reduction. - * @param val the value to reduce - * @param mod the modulus - * @param np a number n such that 2^{N} * a + mod * n == 1 - * @tparam N the size of the input + * @tparam N the number of limbs in the value to reduce. + * @param out the value to reduce. + * @param rp reduction parameters. */ template -void MontyRedc(mp_limb_t* val, const mp_limb_t* mod, const mp_limb_t* np) { - // https://cp-algorithms.com/algebra/montgomery_multiplication.html#montgomery-reduction +void MontyRedc(mp_limb_t* out, const RedParams rp) { + // q = val * rp.mc + // TODO: This can be optimized a bit since q is reduced modulo 2^N below mp_limb_t q[2 * N]; - // TODO: This multiplication can be optimized because we're only interested in - // the result mod r = 2^{N}. - mpn_mul_n(q, val, np, N); + mpn_mul_n(q, out, rp.mc, N); + + // c = (q mod 2^N) * rp.prime mp_limb_t c[2 * N]; - mpn_mul_n(c, q, mod, N); - auto borrow = mpn_sub_n(c, val, c, 2 * N); + mpn_mul_n(c, q, rp.prime, N); - SCL_COPY(val, c + N, N); + // val + c / 2^256 + const auto carry = mpn_add_n(c, out, c, 2 * N); + std::copy(c + N, c + 2 * N, out); - if (borrow) { - mpn_add_n(val, val, mod, N); + if (carry || mpn_cmp(out, rp.prime, N) >= 0) { + mpn_sub_n(out, out, rp.prime, N); } } /** - * @brief Convert an integer into a multi-precision value. - * @param out result - * @param value the int to convert from - * @param mod a modulus + * @brief Convert an integer into a value. + * @tparam N the number of limbs in the output. + * @param out destination of the converted value. + * @param value the int to convert from. + * @param rp reduction parameters. * * This function converts an integer into an \p N limb multi-precision integer * modulo \p mod. The function assumes that \p out has been zeroed. */ template -void MontyInFromInt(mp_limb_t* out, const int value, const mp_limb_t* mod) { +void MontyInFromInt(mp_limb_t* out, const int value, const RedParams rp) { out[0] = std::abs(value); if (value < 0) { - mpn_sub_n(out, mod, out, N); + mpn_sub_n(out, rp.prime, out, N); } - MontyIn(out, mod); + MontyIn(out, rp); } /** * @brief Perform a modular addition. - * @param out the first operand and destination of result - * @param op the second operand - * @param mod the modulus + * @tparam N the number of limbs in the values to add. + * @param out the first operand and destination of result. + * @param op the second operand. + * @param rp reduction parameters. */ template -void MontyModAdd(mp_limb_t* out, const mp_limb_t* op, const mp_limb_t* mod) { +void MontyModAdd(mp_limb_t* out, const mp_limb_t* op, const RedParams rp) { auto carry = mpn_add_n(out, out, op, N); - if (carry || mpn_cmp(out, mod, N) >= 0) { - mpn_sub_n(out, out, mod, N); + if (carry || mpn_cmp(out, rp.prime, N) >= 0) { + mpn_sub_n(out, out, rp.prime, N); } } /** * @brief Perform a modular subtraction. - * @param out the first operand and destination of result - * @param op the second operand - * @param mod the modulus. + * @tparam N the number of limbs in the values to subtract. + * @param out the first operand and destination of result. + * @param op the second operand. + * @param rp reduction parameters. */ template -void MontyModSub(mp_limb_t* out, const mp_limb_t* op, const mp_limb_t* mod) { +void MontyModSub(mp_limb_t* out, const mp_limb_t* op, const RedParams rp) { auto carry = mpn_sub_n(out, out, op, N); if (carry) { - mpn_add_n(out, out, mod, N); + mpn_add_n(out, out, rp.prime, N); } } /** * @brief Perform a modular negation. - * @param out the operand and destination of result - * @param mod the modulus + * @tparam N the number of limbs in the value to negate. + * @param out the operand and destination of result. + * @param rp reduction parameters. */ template -void MontyModNeg(mp_limb_t* out, const mp_limb_t* mod) { +void MontyModNeg(mp_limb_t* out, const RedParams rp) { mp_limb_t t[N] = {0}; - MontyModSub(t, out, mod); - SCL_COPY(out, t, N); + MontyModSub(t, out, rp); + std::copy(t, t + N, out); } /** - * @brief Perform a modular multiplication in montgomery representation - * @param out the first operand and destination of result - * @param op the second operand - * @param mod the modulus - * @param np a constant used for montgomery reduction - * @see MontyRedc + * @brief Multiply two values in Montgomery representation. + * @tparam N the number of limbs in the valus to multiply. + * @param out the first operand and destiantion of result. + * @param op the second operand. + * @param rp reduction parameters. + * + * This function performs an interleaved Montgomery modular + * multiplication. */ template -void MontyModMul(mp_limb_t* out, - const mp_limb_t* op, - const mp_limb_t* mod, - const mp_limb_t* np) { - mp_limb_t res[2 * N]; - mpn_mul_n(res, out, op, N); - MontyRedc(res, mod, np); - SCL_COPY(out, res, N); +void MontyModMul(mp_limb_t* out, const mp_limb_t* op, const RedParams rp) { + mp_limb_t u[N + 1] = {0}; + + for (std::size_t i = 0; i < N; ++i) { + const auto c0 = mpn_addmul_1(u, op, N, out[i]); + const auto q = rp.mc[0] * u[0]; + const auto c1 = mpn_addmul_1(u, rp.prime, N, q); + u[N] += c1 + c0; + std::copy(u + 1, u + N + 1, u); + u[N] = ((c1 & c0) | ((c1 | c0) & ~u[N])) >> (SCL_BITS_PER_LIMB - 1); + } + + std::copy(u, u + N, out); + if (u[N] || mpn_cmp(out, rp.prime, N) >= 0) { + mpn_sub_n(out, out, rp.prime, N); + } } /** - * @brief Perform a modular squaring in montgomery representation - * @param out the output - * @param op the operand to square - * @param mod the modulus - * @param np a constant used for montgomery reduction + * @brief Square a value in Montgomery representation. + * @tparam N the number of limbs in the value to square. + * @param out the output. + * @param op the operand to square. + * @param rp reduction parameters. */ template -void MontyModSqr(mp_limb_t* out, - const mp_limb_t* op, - const mp_limb_t* mod, - const mp_limb_t* np) { +void MontyModSqr(mp_limb_t* out, const mp_limb_t* op, const RedParams rp) { mp_limb_t res[2 * N]; mpn_sqr(res, op, N); - MontyRedc(res, mod, np); - SCL_COPY(out, res, N); + MontyRedc(res, rp); + std::copy(res, res + N, out); } /** @@ -186,54 +215,53 @@ inline bool TestBit(const mp_limb_t* v, std::size_t pos) { } /** - * @brief Modular exponentation - * @param out output. Must initially be equal to 1 in montgomery form - * @param x the base - * @param e the exponent - * @param mod the modulus - * @param np a constant used for montgomery reduction - * - * This function performs a modular exponentation of a multiprecision integer in - * montgomery form. + * @brief Modular exponentation. + * @tparam N the number of limbs in the base. + * @param out output. Must initially be equal to 1 in montgomery form. + * @param base the base. + * @param exp the exponent. + * @param rp reduction parameters. */ template void MontyModExp(mp_limb_t* out, - const mp_limb_t* x, - const mp_limb_t* e, - const mp_limb_t* mod, - const mp_limb_t* np) { - auto n = mpn_sizeinbase(e, N, 2); + const mp_limb_t* base, + const mp_limb_t* exp, + const RedParams rp) { + auto n = mpn_sizeinbase(exp, N, 2); for (std::size_t i = n; i-- > 0;) { - MontyModSqr(out, out, mod, np); - if (TestBit(e, i)) { - MontyModMul(out, x, mod, np); + MontyModSqr(out, out, rp); + if (TestBit(exp, i)) { + MontyModMul(out, base, rp); } } } /** * @brief Compute a modular inverse. - * @param out output destination - * @param op the value to invert - * @param mod the modulus - * @param mod_minus_2 \p mod minus 2 - * @param np a constant used for montgomery reduction + * @tparam N the number of limbs in the value to invert. + * @param out output destination. + * @param op the value to invert. + * @param prime_minus_2 \p rp.prime minus 2. + * @param rp reduction parameters. + * + * This function computes a modular inverse using Fermats little thereom. The \p + * prime_minus_2 argument is assumed to be \f$rp.prime - 2\f$. */ template void MontyModInv(mp_limb_t* out, const mp_limb_t* op, - const mp_limb_t* mod, - const mp_limb_t* mod_minus_2, - const mp_limb_t* np) { + const mp_limb_t* prime_minus_2, + const RedParams rp) { if (mpn_zero_p(op, N)) { throw std::invalid_argument("0 not invertible modulo prime"); } - MontyModExp(out, op, mod_minus_2, mod, np); + MontyModExp(out, op, prime_minus_2, rp); } /** - * @brief Compute a comparison between two values + * @brief Compute a comparison between two values. + * @tparam N the number of limbs in the values to convert. * @return a value x such that R(x, 0) <==> R(lhs, rhs). */ template @@ -243,38 +271,38 @@ int CompareValues(const mp_limb_t* lhs, const mp_limb_t* rhs) { /** * @brief Deserialize a value and convert to montgomery form. - * @param out output destination - * @param src where to read the value from - * @param mod the modulus + * @tparam N the number of limbs in the value to convert. + * @param out output destination. + * @param src where to read the value from. + * @param rp reduction parameters. */ template void MontyFromBytes(mp_limb_t* out, const unsigned char* src, - const mp_limb_t* mod) { + const RedParams rp) { for (int i = N - 1; i >= 0; --i) { for (int j = SCL_BYTES_PER_LIMB - 1; j >= 0; --j) { out[i] |= static_cast(*src++) << (j * 8); } } - MontyIn(out, mod); + MontyIn(out, rp); } /** * @brief Write a value in montgomery form to a buffer. - * @param dest the output buffer - * @param src the input value - * @param mod the modulus - * @param np a montgomery constant + * @tparam N the number of limbs in value to convert. + * @param dest the output buffer. + * @param src the input value. + * @param rp reduction parameters. */ template void MontyToBytes(unsigned char* dest, const mp_limb_t* src, - const mp_limb_t* mod, - const mp_limb_t* np) { + const RedParams rp) { mp_limb_t padded[2 * N] = {0}; - SCL_COPY(padded, src, N); - MontyRedc(padded, mod, np); + std::copy(src, src + N, padded); + MontyRedc(padded, rp); std::size_t c = 0; for (int i = N - 1; i >= 0; --i) { @@ -287,6 +315,7 @@ void MontyToBytes(unsigned char* dest, /** * @brief Find the first non-zero character in a string. + * @return the position of the first non-zero character. * * This method is used handle a string representation of a number with leading * zeros. @@ -294,15 +323,17 @@ void MontyToBytes(unsigned char* dest, std::size_t FindFirstNonZero(const std::string& s); /** - * @brief Print a value. + * @brief Convert a value in Montgomery representation to a string. + * @tparam N the number of limbs in the value to convert. + * @param val the value to convert. + * @param rp reduction parameters used to convert \p val out of Montgomery form. + * @return \p val as a string. */ template -std::string MontyToString(const mp_limb_t* val, - const mp_limb_t* mod, - const mp_limb_t* np) { +std::string MontyToString(const mp_limb_t* val, const RedParams rp) { mp_limb_t padded[2 * N] = {0}; - SCL_COPY(padded, val, N); - MontyRedc(padded, mod, np); + std::copy(val, val + N, padded); + MontyRedc(padded, rp); static const char* kHexChars = "0123456789abcdef"; std::stringstream ss; @@ -331,11 +362,15 @@ std::string MontyToString(const mp_limb_t* val, /** * @brief Read a value from a string. + * @tparam N the number of limbs in the value to convert. + * @param out the output destination. + * @param str the string to read the output from. + * @param rp reduction parameters used to convert out into Montgomery form. */ template void MontyFromString(mp_limb_t* out, - const mp_limb_t* mod, - const std::string& str) { + const std::string& str, + const RedParams rp) { if (str.length()) { auto n_ = str.length(); if (n_ > 64) { @@ -358,13 +393,12 @@ void MontyFromString(mp_limb_t* out, out[c--] = util::FromHexString(std::string(beg + i, beg + end)); } - MontyIn(out, mod); + MontyIn(out, rp); } } #undef SCL_BITS_PER_LIMB #undef SCL_BYTES_PER_LIMB -#undef SCL_COPY } // namespace scl::math diff --git a/include/scl/math/poly.h b/include/scl/math/poly.h index 6ea2104..a36418c 100644 --- a/include/scl/math/poly.h +++ b/include/scl/math/poly.h @@ -77,6 +77,14 @@ class Polynomial { return m_coefficients[idx]; } + /** + * @brief Get the coefficients of this polynomial. + * @return the coefficients. + */ + Vec Coefficients() const { + return m_coefficients; + } + /** * @brief Add two polynomials. */ diff --git a/include/scl/math/ring.h b/include/scl/math/ring.h deleted file mode 100644 index 1b2765e..0000000 --- a/include/scl/math/ring.h +++ /dev/null @@ -1,110 +0,0 @@ -/* SCL --- Secure Computation Library - * Copyright (C) 2023 Anders Dalskov - * - * This program is free software: you can redistribute it and/or modify - * it under the terms of the GNU Affero General Public License as published by - * the Free Software Foundation, either version 3 of the License, or - * (at your option) any later version. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU Affero General Public License for more details. - * - * You should have received a copy of the GNU Affero General Public License - * along with this program. If not, see . - */ - -#ifndef SCL_MATH_RING_H -#define SCL_MATH_RING_H - -#include -#include - -namespace scl::math { - -/** - * @brief Derives some basic operations on Ring elements via. CRTP. - * @deprecated - */ -template -struct Ring { - /** - * @brief Add two elements and return their sum. - */ - friend T operator+(const T& lhs, const T& rhs) { - T temp(lhs); - return temp += rhs; - }; - - /** - * @brief Subtract two elements and return their difference. - */ - friend T operator-(const T& lhs, const T& rhs) { - T temp(lhs); - return temp -= rhs; - }; - - /** - * @brief Return the negation of an element. - */ - friend T operator-(const T& elem) { - T temp(elem); - return temp.Negate(); - }; - - /** - * @brief Multiply two elements and return their product. - */ - friend T operator*(const T& lhs, const T& rhs) { - T temp(lhs); - return temp *= rhs; - }; - - /** - * @brief Divide two elements and return their quotient. - */ - friend T operator/(const T& lhs, const T& rhs) { - T temp(lhs); - return temp /= rhs; - }; - - /** - * @brief Compare two elements for equality. - */ - friend bool operator==(const T& lhs, const T& rhs) { - return lhs.Equal(rhs); - }; - - /** - * @brief Compare two elements for inequality. - */ - friend bool operator!=(const T& lhs, const T& rhs) { - return !(lhs == rhs); - }; - - /** - * @brief Write a string representation of an element to a stream. - */ - friend std::ostream& operator<<(std::ostream& os, const T& r) { - return os << r.ToString(); - }; -}; - -/** - * @brief Use to ensure a template parameter is a ring. - * - * enable_if_ring will check if its first parameter is a ring (i.e., it inherits - * from RingElement) and if so, set type to be V. Otherwise it - * fails to compile. - */ -template -struct EnableIfRing { - //! type when T is a ring. - using Type = - typename std::enable_if, T>::value, V>::type; -}; - -} // namespace scl::math - -#endif // SCL_MATH_RING_H diff --git a/include/scl/math/vec.h b/include/scl/math/vec.h index 2c72c12..033abfd 100644 --- a/include/scl/math/vec.h +++ b/include/scl/math/vec.h @@ -274,7 +274,10 @@ class Vec { * @param scalar the scalar * @return a scaled version of this vector. */ - Vec ScalarMultiply(const Elem& scalar) const { + template < + typename Scalar, + std::enable_if_t::value, bool> = true> + Vec ScalarMultiply(const Scalar& scalar) const { std::vector r; r.reserve(Size()); for (const auto& v : m_values) { @@ -288,7 +291,10 @@ class Vec { * @param scalar the scalar * @return a scaled version of this vector. */ - Vec& ScalarMultiplyInPlace(const Elem& scalar) { + template < + typename Scalar, + std::enable_if_t::value, bool> = true> + Vec& ScalarMultiplyInPlace(const Scalar& scalar) { for (auto& v : m_values) { v *= scalar; } diff --git a/include/scl/net/channel.h b/include/scl/net/channel.h index 5ab9f22..9d606f4 100644 --- a/include/scl/net/channel.h +++ b/include/scl/net/channel.h @@ -93,9 +93,6 @@ class Channel { virtual std::optional Recv(bool block = true); }; -#undef SCL_C -#undef SCL_CC - } // namespace scl::net #endif // SCL_NET_CHANNEL_H diff --git a/include/scl/protocol/base.h b/include/scl/protocol/base.h index ef1cad5..721b8cb 100644 --- a/include/scl/protocol/base.h +++ b/include/scl/protocol/base.h @@ -48,7 +48,7 @@ struct Protocol { /** * @brief Default protocol name. */ - constexpr static const char* kDefaultName = "UNNAMED"; + constexpr static const char* DEFAULT_NAME = "UNNAMED"; virtual ~Protocol(){}; /** @@ -67,7 +67,7 @@ struct Protocol { * other. The default value is Protocol::kDefaultName. */ virtual std::string Name() const { - return Protocol::kDefaultName; + return Protocol::DEFAULT_NAME; } /** diff --git a/include/scl/simulation/buffer.h b/include/scl/simulation/buffer.h index d6e9a81..50ee758 100644 --- a/include/scl/simulation/buffer.h +++ b/include/scl/simulation/buffer.h @@ -36,16 +36,17 @@ struct ChannelBuffer { /** * @brief Read data from the channel. - * @param n the number of bytes to read - * @return the data. + * @param data the data to write. + * @param n the number of bytes to read. */ - virtual std::vector Read(std::size_t n) = 0; + virtual void Read(unsigned char* data, std::size_t n) = 0; /** * @brief Write data to the channel. - * @param data the data to write + * @param data the data to write. + * @param n the number of bytes to write. */ - virtual void Write(const std::vector& data) = 0; + virtual void Write(const unsigned char* data, std::size_t n) = 0; /** * @brief Get the amount of bytes that can be read from this channel. diff --git a/include/scl/simulation/channel.h b/include/scl/simulation/channel.h index c79624b..372a287 100644 --- a/include/scl/simulation/channel.h +++ b/include/scl/simulation/channel.h @@ -35,7 +35,7 @@ namespace scl::sim { * This function simply generates a CLOSE event for the current * time of the running party. */ -std::shared_ptr SimulateClose(std::shared_ptr ctx, +std::shared_ptr SimulateClose(std::shared_ptr ctx, ChannelId id); /** @@ -54,7 +54,7 @@ std::shared_ptr SimulateClose(std::shared_ptr ctx, * records a write operation on the context with the time in the * SEND event for the number \p n of bytes sent. */ -std::shared_ptr SimulateSend(std::shared_ptr ctx, +std::shared_ptr SimulateSend(std::shared_ptr ctx, ChannelId id, const unsigned char* src, std::size_t n); @@ -76,7 +76,7 @@ std::shared_ptr SimulateSend(std::shared_ptr ctx, *

The time in the RECV event is adjusted by going through the * recorded write operations for the sending channel */ -std::shared_ptr SimulateRecv(std::shared_ptr ctx, +std::shared_ptr SimulateRecv(std::shared_ptr ctx, ChannelId id, unsigned char* dst, std::size_t n); @@ -98,7 +98,7 @@ std::shared_ptr SimulateRecv(std::shared_ptr ctx, * possible to determine if there are data available. */ std::pair> SimulateHasData( - std::shared_ptr ctx, + std::shared_ptr ctx, ChannelId id); /** @@ -109,15 +109,14 @@ std::pair> SimulateHasData( * sim::SimulateHasData, which performs the actual simulation of the methods in * the Channel interface. */ -class SimulatedChannel final : public net::Channel { +class Channel final : public net::Channel { public: /** * @brief Construct a new Channel for simulations. * @param id the ID of the channel * @param ctx a simulation context object */ - SimulatedChannel(ChannelId id, std::shared_ptr ctx) - : m_id(id), m_ctx(ctx){}; + Channel(ChannelId id, std::shared_ptr ctx) : m_id(id), m_ctx(ctx){}; void Close() override { m_ctx->AddEvent(m_id.local, SimulateClose(m_ctx, m_id)); @@ -144,7 +143,7 @@ class SimulatedChannel final : public net::Channel { private: ChannelId m_id; - std::shared_ptr m_ctx; + std::shared_ptr m_ctx; }; } // namespace scl::sim diff --git a/include/scl/simulation/channel_id.h b/include/scl/simulation/channel_id.h index 5dd734a..85bf449 100644 --- a/include/scl/simulation/channel_id.h +++ b/include/scl/simulation/channel_id.h @@ -20,6 +20,7 @@ #include #include +#include namespace scl::sim { @@ -77,7 +78,14 @@ struct ChannelId { friend bool operator<(const ChannelId& cid0, const ChannelId& cid1) { return cid0.local < cid1.local || (cid0.local == cid1.local && cid0.remote < cid1.remote); - }; + } + + /** + * @brief Print operator for ChannelId. + */ + friend std::ostream& operator<<(std::ostream& os, const ChannelId& cid) { + return os << "ChannelId{" << cid.local << ", " << cid.remote << "}"; + } }; } // namespace scl::sim @@ -87,7 +95,7 @@ struct ChannelId { template <> struct std::hash { std::size_t operator()(const scl::sim::ChannelId& cid) const { - return hash{}(cid.local) ^ (hash{}(cid.remote) << 3); + return cid.local ^ (cid.remote << 32); } }; diff --git a/include/scl/simulation/config.h b/include/scl/simulation/config.h index 6794b93..203c4f4 100644 --- a/include/scl/simulation/config.h +++ b/include/scl/simulation/config.h @@ -29,49 +29,76 @@ namespace scl::sim { /** - * @brief Configuration for the simulated network. + * @brief Configuration for a channel between two parties. */ -class SimulatedNetworkConfig { +class ChannelConfig { public: /** * @brief Builder used to create network configs. */ class Builder; + /** + * @brief Indicates which type of network the channel is emulating. + */ + enum class NetworkType { + /** + * @brief The channel is a TCP channel. + */ + TCP, + + /** + * @brief The channel is a special channel where communication is instant. + */ + INSTANT, + }; + + /** + * @brief Default network type is TCP. + */ + constexpr static NetworkType DEFAULT_NETWORK_TYPE = NetworkType::TCP; + /** * @brief Default bandwidth of the simulated network, in bits/s. */ - constexpr static std::size_t kDefaultBandwidth = 1000000; + constexpr static std::size_t DEFAULT_BANDWIDTH = 1000000; /** * @brief Default RTT of the simulated network in ms. */ - constexpr static std::size_t kDefaultRTT = 100; + constexpr static std::size_t DEFAULT_RTT = 100; /** * @brief Default MSS in bytes. */ - constexpr static std::size_t kDefaultMSS = 1460; + constexpr static std::size_t DEFAULT_MSS = 1460; /** * @brief Default package loss in percentage. */ - constexpr static double kDefaultPackageLoss = 0; + constexpr static double DEFAULT_PACKAGE_LOSS = 0; /** * @brief Default TCP window size in bytes. */ - constexpr static std::size_t kDefaultWindowSize = 65536; + constexpr static std::size_t DEFAULT_WINDOW_SIZE = 65536; /** * @brief Create a simulation config with default values. */ - static SimulatedNetworkConfig Default(); + static ChannelConfig Default(); /** * @brief Create a simulation config for a loopback connection. */ - static SimulatedNetworkConfig Loopback(); + static ChannelConfig Loopback(); + + /** + * @brief The network type of the channel. + */ + NetworkType Type() const { + return m_type; + } /** * @brief Bandwidth in Bits/s. @@ -109,17 +136,20 @@ class SimulatedNetworkConfig { }; private: - SimulatedNetworkConfig(std::size_t bandwidth, - std::size_t rtt, - std::size_t MSS, - double package_loss, - std::size_t window_size) - : m_bandwidth(bandwidth), + ChannelConfig(NetworkType type, + std::size_t bandwidth, + std::size_t rtt, + std::size_t MSS, + double package_loss, + std::size_t window_size) + : m_type(type), + m_bandwidth(bandwidth), m_rtt(rtt), m_MSS(MSS), m_package_loss(package_loss), m_window_size(window_size){}; + NetworkType m_type; std::size_t m_bandwidth; std::size_t m_rtt; std::size_t m_MSS; @@ -130,13 +160,12 @@ class SimulatedNetworkConfig { /** * @brief Pretty print the simulation config. */ -std::ostream& operator<<(std::ostream& os, - const SimulatedNetworkConfig& config); +std::ostream& operator<<(std::ostream& os, const ChannelConfig& config); /** * @brief Builder used to create network configs. */ -class SimulatedNetworkConfig::Builder { +class ChannelConfig::Builder { public: /** * @brief Create an empty simulation config builder. @@ -146,16 +175,27 @@ class SimulatedNetworkConfig::Builder { /** * @brief Build the simulation config. */ - SimulatedNetworkConfig Build() const { + ChannelConfig Build() const { Validate(); - return SimulatedNetworkConfig{ - m_bandwidth.value_or(SimulatedNetworkConfig::kDefaultBandwidth), - m_rtt.value_or(SimulatedNetworkConfig::kDefaultRTT), - m_MSS.value_or(SimulatedNetworkConfig::kDefaultMSS), - m_package_loss.value_or(SimulatedNetworkConfig::kDefaultPackageLoss), - m_window_size.value_or(SimulatedNetworkConfig::kDefaultWindowSize)}; + return ChannelConfig{ + m_type.value_or(ChannelConfig::DEFAULT_NETWORK_TYPE), + m_bandwidth.value_or(ChannelConfig::DEFAULT_BANDWIDTH), + m_rtt.value_or(ChannelConfig::DEFAULT_RTT), + m_MSS.value_or(ChannelConfig::DEFAULT_MSS), + m_package_loss.value_or(ChannelConfig::DEFAULT_PACKAGE_LOSS), + m_window_size.value_or(ChannelConfig::DEFAULT_WINDOW_SIZE)}; }; + /** + * @brief Set the network type of this channel. + * @param type the network type. + * @return the builder. + */ + Builder& Type(NetworkType type) { + m_type = type; + return *this; + } + /** * @brief Set network bandwidth to use for the simulation. * @param bandwidth bandwidth in bits/s. @@ -207,6 +247,7 @@ class SimulatedNetworkConfig::Builder { } private: + std::optional m_type; std::optional m_bandwidth; std::optional m_rtt; std::optional m_MSS; @@ -218,31 +259,32 @@ class SimulatedNetworkConfig::Builder { }; /** - * @brief Creator object for simulation network configs. - * - * The creator should be a function which receives a channel identifier and - * returns network config for that channel. + * @brief Interface describing the network wide configuration. */ -using SimulatedNetworkConfigCreator = - std::function; +struct NetworkConfig { + /** + * @brief Destructor. + */ + virtual ~NetworkConfig() {} + + /** + * @brief Returns the configuration of a particular channel. + */ + virtual ChannelConfig Get(ChannelId channel_id) = 0; +}; /** - * @brief Default config creator implementation. + * @brief Network configuration for a simple network. * - * This implementation returns a default config for all channels; + * SimpleNetworkConfig describes a network where everyone is connected on a + * channel configured according to ChannelConfig::Default. The only exception + * being channels that are self-connecting (i.e., from a party to itself). These + * channels are configured according to ChannelConfig::Loopback. */ -struct DefaultConfigCreator { - /** - * @brief Return a config based on a ChannelId. - * @param channel_id the ChannelId - * - * If the channel_id specifies a channel between two different peers, then - * sim::SimulatedNetworkConfig::Default() is returned, otherwise - * sim::SimulatedNetworkConfig::Loopback() is returned. - */ - SimulatedNetworkConfig operator()(ChannelId channel_id) { - static auto config = SimulatedNetworkConfig::Default(); - static auto lo = SimulatedNetworkConfig::Loopback(); +struct SimpleNetworkConfig final : public NetworkConfig { + ChannelConfig Get(ChannelId channel_id) override { + static auto config = ChannelConfig::Default(); + static auto lo = ChannelConfig::Loopback(); return (channel_id.local == channel_id.remote) ? lo : config; } diff --git a/include/scl/simulation/context.h b/include/scl/simulation/context.h index ba5aeda..785c7fe 100644 --- a/include/scl/simulation/context.h +++ b/include/scl/simulation/context.h @@ -18,10 +18,11 @@ #ifndef SCL_SIMULATION_CONTEXT_H #define SCL_SIMULATION_CONTEXT_H -#include #include #include +#include #include +#include #include "scl/simulation/buffer.h" #include "scl/simulation/channel_id.h" @@ -34,15 +35,27 @@ namespace scl::sim { /** * @brief Context for simulations. */ -class SimulationContext { +class Context { private: enum class State { PREPARE, COMMIT, ROLLBACK }; public: + /** + * @brief Provides a read-only view of a Context. + */ + class View; + /** * @brief A write operation on the channel. */ struct WriteOp { + /** + * @brief Construct a new WriteOp. + * @param amount the amount of data in the write operation. + * @param time the time of the write operation. + */ + WriteOp(std::size_t amount, util::Time::Duration time) + : amount(amount), time(time) {} /** * @brief The amount of data written. */ @@ -66,9 +79,8 @@ class SimulationContext { * implementations that currently exist in SCL. */ template - static std::shared_ptr Create( - std::size_t number_of_parties, - const SimulatedNetworkConfigCreator& config); + static std::shared_ptr Create(std::size_t number_of_parties, + std::shared_ptr config); /** * @brief Construct a new simulation context. @@ -77,16 +89,16 @@ class SimulationContext { * This constructor simply sets the network config for the context but * otherwise performs no initialization whatsoever. Use Create instead. */ - SimulationContext(const SimulatedNetworkConfigCreator& config) - : m_network_conf_creator(config), m_nparties(0) {} + Context(std::shared_ptr config) + : m_network_config(config), m_nparties(0) {} /** - * @brief Get the network config for a particular channel. + * @brief Get the config for a channel. * @param channel_id the ID of the channel * @return a SimulatedNetworkConfig for the channel. */ - SimulatedNetworkConfig NetworkConfig(ChannelId channel_id) const { - return m_network_conf_creator(channel_id); + ChannelConfig ChannelConfiguration(ChannelId channel_id) const { + return m_network_config->Get(channel_id); } /** @@ -106,21 +118,44 @@ class SimulationContext { } /** - * @brief Record a write operation - * @param id the ID of the channel the write was performed on - * @param n the number of bytes written - * @param ts a timestamp indicating when the write took place + * @brief Add a write operation. + * @param id the identifier of the channel that the write occured on. + * @param n the number of bytes written. + * @param time the time the write happened. */ - void RecordWrite(ChannelId id, std::size_t n, util::Time::Duration ts) { - m_writes[id].emplace_back(WriteOp{n, ts}); + void AddWrite(ChannelId id, std::size_t n, util::Time::Duration time) { + m_writes[id].emplace(n, time); } /** - * @brief Get recorded write operations on a particular channel. - * @param id the ID of the channel + * @brief Check if a channel has any unprocessed writes on it. + * @param id the identifier for the channel. + * @return true if the channel has unprocessed writes. False otherwise. */ - std::vector& Writes(ChannelId id) { - return m_writes[id]; + bool HasWrite(ChannelId id) const { + return !(m_writes.find(id) == m_writes.end() || m_writes.at(id).empty()); + } + + /** + * @brief Get the next write on a channel. + * @param id the identifier of the channel. + * @return a write operation. + * + * This method does not check if there are any writes. + */ + WriteOp& NextWrite(ChannelId id) { + return m_writes[id].front(); + } + + /** + * @brief Delete a write operation. + * @param id the identifier of the channel. + * + * This method is meant to be called after a write operation has had all its + * data processed. In a nutshell, when op.amount == 0. + */ + void DeleteWrite(ChannelId id) { + m_writes[id].pop(); } /** @@ -146,6 +181,19 @@ class SimulationContext { return m_traces[id]; } + /** + * @brief Check if a party has terminated. + * @param id the ID of the party. + * @return true if the party has terminated, and otherwise false. + */ + bool HasTerminated(std::size_t id) const { + if (Trace(id).empty()) { + return false; + } + const auto t = Trace(id).back()->EventType(); + return t == sim::Event::Type::STOP || t == sim::Event::Type::KILLED; + } + /** * @brief Remove and return the last event added by a party. */ @@ -210,20 +258,25 @@ class SimulationContext { */ void Rollback(std::size_t id); + /** + * @brief Obtain a View of this context. + */ + View GetView(); + private: - SimulatedNetworkConfigCreator m_network_conf_creator; + std::shared_ptr m_network_config; std::size_t m_nparties; std::vector m_traces; std::size_t m_trace_index; - std::map> m_buffers; + std::unordered_map> m_buffers; State m_state = State::COMMIT; - std::map> m_writes; - std::map> m_writes_backup; + std::unordered_map> m_writes; + std::unordered_map> m_writes_backup; util::Time::TimePoint m_checkpoint; @@ -234,12 +287,52 @@ class SimulationContext { * @brief Create a simulation context with in-memory channels. */ template <> -std::shared_ptr -SimulationContext::Create( +std::shared_ptr Context::Create( std::size_t number_of_parties, - const SimulatedNetworkConfigCreator& config); + std::shared_ptr config); + +/** + * @brief View of a context. + * + * View provides a read-only view of certain parts of the current Context. + */ +class Context::View { + public: + /** + * @brief Get the trace of a party. + * @param id the ID of the party. + */ + SimulationTrace Trace(std::size_t id) const { + return m_ctx.Trace(id); + } + + /** + * @brief Check if a party has terminated. + * @param id the ID of the party. + * @return true if the party has terminated, and otherwise false. + */ + bool HasTerminated(std::size_t id) const { + return m_ctx.HasTerminated(id); + } + + /** + * @brief Get the total number of parties in the simulation. + */ + std::size_t NumberOfParties() const { + return m_ctx.NumberOfParties(); + } + + private: + friend Context; + + View(const Context& ctx) : m_ctx(ctx) {} + + const Context& m_ctx; +}; -// template <> +inline Context::View Context::GetView() { + return Context::View(*this); +} } // namespace scl::sim diff --git a/include/scl/simulation/env.h b/include/scl/simulation/env.h index 7fa391c..ecb87de 100644 --- a/include/scl/simulation/env.h +++ b/include/scl/simulation/env.h @@ -29,15 +29,14 @@ namespace scl::sim { /** * @brief A ProtocolEnvironment::Clock implementation for simulated protocols. */ -class SimulatedClock final : public proto::Env::Clock { +class Clock final : public proto::Env::Clock { public: /** * @brief Create a new clock for simulations. * @param ctx a simulation context. Used to read the current time of the party * @param id the ID of the party */ - SimulatedClock(std::shared_ptr ctx, std::size_t id) - : m_ctx(ctx), m_id(id){}; + Clock(std::shared_ptr ctx, std::size_t id) : m_ctx(ctx), m_id(id){}; /** * @brief Get the total elapsed time of this party. @@ -60,21 +59,21 @@ class SimulatedClock final : public proto::Env::Clock { } private: - std::shared_ptr m_ctx; + std::shared_ptr m_ctx; std::size_t m_id; }; /** * @brief A ProtocolEnvironment::Thread implementation for simulated protocols. */ -class SimulatedThreadCtx final : public proto::Env::Thread { +class ThreadCtx final : public proto::Env::Thread { public: /** * @brief Create a new thread context for simulations. * @param ctx a simulation context * @param id the ID of the party */ - SimulatedThreadCtx(std::shared_ptr ctx, std::size_t id) + ThreadCtx(std::shared_ptr ctx, std::size_t id) : m_ctx(ctx), m_id(id){}; /** @@ -94,7 +93,7 @@ class SimulatedThreadCtx final : public proto::Env::Thread { } private: - std::shared_ptr m_ctx; + std::shared_ptr m_ctx; std::size_t m_id; }; diff --git a/include/scl/simulation/event.h b/include/scl/simulation/event.h index 98aa50c..183704e 100644 --- a/include/scl/simulation/event.h +++ b/include/scl/simulation/event.h @@ -104,7 +104,12 @@ class Event { /** * @brief Event made when a party receives a net::Packet. */ - PACKET_RECV + PACKET_RECV, + + /** + * @brief Event made when a party is stopped prematurely. + */ + KILLED }; /** @@ -337,22 +342,27 @@ class CheckpointEvent final : public Event { /** * @brief Create a new checkpoint event. * @param timestamp the time of the event. - * @param message the message of the checkpoint. + * @param id the id of the checkpoint. */ - CheckpointEvent(util::Time::Duration timestamp, const std::string& message) - : Event(Event::Type::CHECKPOINT, timestamp), m_message(message) {} + CheckpointEvent(util::Time::Duration timestamp, const std::string& id) + : Event(Event::Type::CHECKPOINT, timestamp), m_id(id) {} /** - * @brief Get the checkpoint message. + * @brief Get the checkpoint id. */ - std::string Message() const { - return m_message; + std::string Id() const { + return m_id; } private: - std::string m_message; + std::string m_id; }; +/** + * @brief Pretty print an event type. + */ +std::ostream& operator<<(std::ostream& os, Event::Type type); + /** * @brief Pretty print a measurement to a stream. */ diff --git a/include/scl/simulation/manager.h b/include/scl/simulation/manager.h new file mode 100644 index 0000000..00b1df4 --- /dev/null +++ b/include/scl/simulation/manager.h @@ -0,0 +1,149 @@ +/* SCL --- Secure Computation Library + * Copyright (C) 2023 Anders Dalskov + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +#ifndef SCL_SIMULATION_MANAGER_H +#define SCL_SIMULATION_MANAGER_H + +#include +#include +#include + +#include "scl/protocol/base.h" +#include "scl/simulation/config.h" +#include "scl/simulation/context.h" + +namespace scl::sim { + +/** + * @brief Manager for a simulation. + * + * The role of a Manager object is to describe the different parameters that + * goes into simulation, such as how the network behaves, how to handle outputs + * and for how many replications to run. + */ +class Manager { + public: + /** + * @brief Construct a new manager. + * @param replications the number of replications to simulate. + */ + Manager(std::size_t replications) : m_replications(replications) {} + + /** + * @brief Destructor. + */ + virtual ~Manager() {} + + /** + * @brief Get the number of replications. + */ + std::size_t Replications() const { + return m_replications; + } + + /** + * @brief Return a fresh instance of the protocol to simulate. + * + * Each simulation replication requires a fresh protocol instance to + * run. This function takes care of returning such a protocol. The simulator + * is assumed to take complete ownership over the returned protocol, so it is + * important that objects returned by this function are independent of objects + * previously returned by calling this function. + */ + virtual std::vector> Protocol() = 0; + + /** + * @brief Handle the output produced by some party. + * @param replication the replication that the output was produced in. + * @param party_id the ID of the party who produced the output. + * @param output the output. + * + * The default implementation simply discards the output. + */ + virtual void HandleOutput(std::size_t replication, + std::size_t party_id, + const std::any& output) { + (void)replication; + (void)party_id; + (void)output; + } + + /** + * @brief Get the configuration for the network. + * + * The default is to return a SimpleNetworkConfig instance. + */ + virtual std::shared_ptr NetworkConfiguration() { + return std::make_shared(); + } + + /** + * @brief Decide whether to terminate a party. + * @param party_id the ID of the party. + * @param view a view of the simulation context. + * + *

Under normal circumstances, a party is terminated when its Run function + * returns nullptr. This function can be used to terminate a + * party prematurely, e.g., after it has been running for a certain amount of + * time. + * + *

The default implementation never terminates parties prematurely. + */ + virtual bool Terminate(std::size_t party_id, const Context::View& view) { + (void)party_id; + (void)view; + return false; + } + + private: + std::size_t m_replications; +}; + +/** + * @brief A simple simulation manager which allows running a protocol once. + */ +class SingleReplicationManager final : public Manager { + public: + /** + * @brief Construct a new SingleReplicationManager. + * @param protocol the protocol to run + */ + SingleReplicationManager( + std::vector> protocol) + : Manager(1), m_protocol(std::move(protocol)), m_used(false) {} + + /** + * @brief Get the protocol to simulate. + * @throws std::logic_error if this function is called more than once. + */ + std::vector> Protocol() { + if (m_used) { + throw std::logic_error( + "Protocol called twice on SingleReplicationManager"); + } + m_used = true; + return std::move(m_protocol); + } + + private: + std::vector> m_protocol; + bool m_used; +}; + +} // namespace scl::sim + +#endif // SCL_SIMULATION_MANAGER_H diff --git a/include/scl/simulation/measurement.h b/include/scl/simulation/measurement.h index dc76b85..2aba9d6 100644 --- a/include/scl/simulation/measurement.h +++ b/include/scl/simulation/measurement.h @@ -29,16 +29,6 @@ namespace scl::sim { /** * @brief Measurement from a simulation. - * - *

A measurement holds the raw samples that are extracted from a protocol, - * but provides several functions that derive useful statistics about these - * samples. Provided statistics are: - * - *

    - *
  • Mean() and Median()
  • - *
  • Min() and Max()
  • - *
  • StdDev() i.e., standard deviation
  • - *
*/ template class Measurement { @@ -49,9 +39,6 @@ class Measurement { */ void AddSample(const T& sample) { m_samples.emplace_back(sample); - // Maybe it makes more sense to use a datastructure here where insertion - // anywhere is constant time? - std::sort(m_samples.begin(), m_samples.end()); } /** @@ -78,114 +65,10 @@ class Measurement { return m_samples.empty(); } - /** - * @brief Mean. - * @return mean of the samples in this measurement. - */ - T Mean() const; - - /** - * @brief Median. - * @return median of the samples in this measurement. - */ - T Median() const; - - /** - * @brief Mininum. - * @return smallest observed sample. - */ - T Min() const { - return Empty() ? Zero() : m_samples[0]; - } - - /** - * @brief Maximum. - * @return largest observed sample. - */ - T Max() const { - return Empty() ? Zero() : m_samples[Size() - 1]; - } - - /** - * @brief Standard deviation. - * @return standard deviation of the samples in this measurement. - */ - T StdDev() const; - private: std::vector m_samples; - - static T Zero() { - return 0; - } - - static T Sqrt(const T& v) { - return std::sqrt(v); - } - - static T Sqr(const T& v) { - return v * v; - } }; -template -T Measurement::Mean() const { - T sum = Zero(); - for (const auto& v : m_samples) { - sum += v; - } - return sum / Size(); -} - -template -T Measurement::Median() const { - if (Empty()) { - return Zero(); - } - - if (Size() == 1) { - return m_samples[0]; - } - - const auto i = Size() / 2; - if (Size() % 2 == 0) { - return m_samples[i]; - } - - return m_samples[i] + m_samples[i + 1]; -} - -template -T Measurement::StdDev() const { - const auto mu = Mean(); - auto sum = Zero(); - for (const auto& v : m_samples) { - sum += Sqr(v - mu); - } - return Sqrt(sum / Size()); -} - -template <> -inline util::Time::Duration Measurement::Zero() { - return util::Time::Duration::zero(); -} - -template <> -inline util::Time::Duration Measurement::Sqrt( - const util::Time::Duration& v) { - long double u = std::sqrt(v.count()); - std::chrono::duration w(u); - return std::chrono::duration_cast(w); -} - -template <> -inline util::Time::Duration Measurement::Sqr( - const util::Time::Duration& v) { - long double u = v.count(); - std::chrono::duration w(u * u); - return std::chrono::duration_cast(w); -} - /** * @brief A measurement for time related observations. * diff --git a/include/scl/simulation/mem_channel_buffer.h b/include/scl/simulation/mem_channel_buffer.h index 3994471..6dc256c 100644 --- a/include/scl/simulation/mem_channel_buffer.h +++ b/include/scl/simulation/mem_channel_buffer.h @@ -36,12 +36,15 @@ namespace scl::sim { * and reads to be rolled back. */ class MemoryBackedChannelBuffer final : public ChannelBuffer { + // type of the internal buffer + using BufferT = std::vector; + public: /** * @brief Create a channel buffer connected to itself. */ static std::shared_ptr CreateLoopback() { - auto buf = std::make_shared>(); + auto buf = std::make_shared(); return std::make_shared(buf, buf); } @@ -49,8 +52,8 @@ class MemoryBackedChannelBuffer final : public ChannelBuffer { * @brief Create a pair of paired channels. */ static std::array, 2> CreatePaired() { - auto buf0 = std::make_shared>(); - auto buf1 = std::make_shared>(); + auto buf0 = std::make_shared(); + auto buf1 = std::make_shared(); return {std::make_shared(buf0, buf1), std::make_shared(buf1, buf0)}; } @@ -60,9 +63,8 @@ class MemoryBackedChannelBuffer final : public ChannelBuffer { * @param write_buffer buffer for storing writes * @param read_buffer buffer for storing reads */ - MemoryBackedChannelBuffer( - std::shared_ptr> write_buffer, - std::shared_ptr> read_buffer) + MemoryBackedChannelBuffer(std::shared_ptr write_buffer, + std::shared_ptr read_buffer) : m_write_buf(write_buffer), m_read_buf(read_buffer), m_write_ptr(0), @@ -74,18 +76,15 @@ class MemoryBackedChannelBuffer final : public ChannelBuffer { return m_read_buf->size() - m_read_ptr; } - std::vector Read(std::size_t n) override { - // silence clang-tidy - auto m = (std::vector::difference_type)m_read_ptr; - auto n_ = (std::vector::difference_type)n; - std::vector data{m_read_buf->begin() + m, - m_read_buf->begin() + m + n_}; + void Read(unsigned char* data, std::size_t n) override { + const auto m = (BufferT::difference_type)m_read_ptr; + const auto n_ = (BufferT::difference_type)n; + std::copy(m_read_buf->begin() + m, m_read_buf->begin() + m + n_, data); m_read_ptr += n; - return data; } - void Write(const std::vector& data) override { - m_write_buf->insert(m_write_buf->end(), data.begin(), data.end()); + void Write(const unsigned char* data, std::size_t n) override { + m_write_buf->insert(m_write_buf->end(), data, data + n); } void Prepare() override { @@ -95,7 +94,7 @@ class MemoryBackedChannelBuffer final : public ChannelBuffer { void Commit() override { // erase the data that was read since Prepare and reset write/read ptr. - auto m = (std::vector::difference_type)m_read_ptr; + auto m = (BufferT::difference_type)m_read_ptr; m_read_buf->erase(m_read_buf->begin(), m_read_buf->begin() + m); m_read_ptr = 0; @@ -109,8 +108,8 @@ class MemoryBackedChannelBuffer final : public ChannelBuffer { } private: - std::shared_ptr> m_write_buf; - std::shared_ptr> m_read_buf; + std::shared_ptr m_write_buf; + std::shared_ptr m_read_buf; std::size_t m_write_ptr; std::size_t m_read_ptr; diff --git a/include/scl/simulation/result.h b/include/scl/simulation/result.h index 4adfc05..4a538ee 100644 --- a/include/scl/simulation/result.h +++ b/include/scl/simulation/result.h @@ -20,6 +20,7 @@ #include #include +#include #include #include @@ -83,7 +84,7 @@ class Result { *

This function is used by Simulate() to create its return value after * running a simulation. The input to this function is a list of traces * traces where traces[i][j] is trace from i'th - * iteration of party j. + * replication of party j. * *

Internally, this function will collect and aggregate all traces created * when simulation a party, and output a Result object for each party. @@ -145,11 +146,11 @@ class Result { /** * @brief Write a trace to a stream. * @param stream the stream to write the trace to. - * @param iteration the simulation iteration + * @param replication the simulation replication * @param name the segment. None if the entire trace should be written. */ void WriteTrace(std::ostream& stream, - std::size_t iteration, + std::size_t replication, const SegmentName& name = {}) const; /** @@ -158,15 +159,35 @@ class Result { */ void Write(std::ostream& stream) const; + /** + * @brief Get the simulation trace from a particular replication. + * @param replication the replication. + * @return the simulation trace from a replication. + */ + SimulationTrace Trace(std::size_t replication) const { + return m_traces[replication]; + } + + /** + * @brief Get the measurement associated with a checkpoint. + * @param key the string identifying the checkpoint. + * @return the time measurement. + */ + TimeMeasurement Checkpoint(const std::string& key) const { + return m_checkpoints.at(key); + } + private: static Result Create(const std::vector& traces); Result( const std::vector& traces, const std::unordered_map& measurements, + const std::unordered_map& checkpoints, const std::vector& segment_names) : m_traces(traces), m_measurements(measurements), + m_checkpoints(checkpoints), m_segment_names(segment_names){}; // The raw simulation trace @@ -175,6 +196,9 @@ class Result { // per-segment measurements std::unordered_map m_measurements; + // user made checkpoints + std::unordered_map m_checkpoints; + // segment names std::vector m_segment_names; }; diff --git a/include/scl/simulation/simulator.h b/include/scl/simulation/simulator.h index 026a1c2..f0c610e 100644 --- a/include/scl/simulation/simulator.h +++ b/include/scl/simulation/simulator.h @@ -24,11 +24,13 @@ #include #include #include +#include #include #include "scl/protocol/base.h" #include "scl/simulation/config.h" #include "scl/simulation/event.h" +#include "scl/simulation/manager.h" #include "scl/simulation/result.h" namespace scl::sim { @@ -59,85 +61,25 @@ struct SimulationFailure final : public std::runtime_error { * number of bytes is specified by the second argument \p n while the network * conditions (bandwidth, latency, overhead, etc...) is specified by \p config. */ -util::Time::Duration ComputeRecvTime(const SimulatedNetworkConfig& config, +util::Time::Duration ComputeRecvTime(const ChannelConfig& config, std::size_t n); /** - * @brief Protocol creator interface. - * - * A ProtocolCreator is a supplier of the protocol that is being simulated. - * Simulations need to be run multiple times in order to get good measurements. - * This interface captures a type whose only job is to return a fresh - * protocol definition every time it is called. - */ -using ProtocolCreator = - std::function>()>; - -/** - * @brief Callback function types when a party creates an output. - * - * Whenever a party produces output during a simulation, a callback of this type - * is called with the party's ID as the first argument, and the output produced - * as the second. - */ -using OutputCallback = std::function; - -/** - * @brief Simulate a protocol execution. - * @param protocol_creator a creator object for the protocol being simulated - * @param config_creator a simulation config creator object - * @param iterations how many iterations the simulation should run for - * @param output_cb a function that is called when a party produce an output + * @brief Simulate the execution of a protocol. + * @param manager a simulation manager. * @return the simulation result. */ -std::vector Simulate( - const ProtocolCreator& protocol_creator, - const SimulatedNetworkConfigCreator& config_creator, - std::size_t iterations, - const OutputCallback& output_cb); +std::vector Simulate(std::unique_ptr manager); /** - * @brief Simulate a protocol execution. - * @param parties the parties of the protocol - * @param config_creator a simulation config creator object - * @param output_cb a function that is called when a party produce an output + * @brief Simulate a protocol for a single replication. + * @param protocol the protocol. * @return the simulation result. */ -std::vector Simulate( - std::vector> parties, - const SimulatedNetworkConfigCreator& config_creator, - const OutputCallback& output_cb); - -/** - * @brief Simulate a protocol execution. - * @param protocol_creator a creator object for the protocol being simulated - * @param config_creator a simulation config creator object - * @param iterations the number of iterations to run the simulation for - */ -inline std::vector Simulate( - const ProtocolCreator& protocol_creator, - const SimulatedNetworkConfigCreator& config_creator, - std::size_t iterations) { - const auto cb = [](auto id, auto output) { - (void)id; - (void)output; - }; - return Simulate(protocol_creator, config_creator, iterations, cb); -} - -/** - * @brief Simulate a protocol execution. - * @param parties the parties of the protocol to simulate - * @param config_creator a simulation config creator object - */ inline std::vector Simulate( - std::vector> parties, - const SimulatedNetworkConfigCreator& config_creator) { - const auto cb = [](auto id, auto output) { - (void)id; - (void)output; - }; - return Simulate(std::move(parties), config_creator, cb); + std::vector> protocol) { + return Simulate( + std::make_unique(std::move(protocol))); } } // namespace scl::sim diff --git a/include/scl/ss/feldman.h b/include/scl/ss/feldman.h index 3b0a964..8ae9a41 100644 --- a/include/scl/ss/feldman.h +++ b/include/scl/ss/feldman.h @@ -38,7 +38,7 @@ struct FeldmanSharing { /** * @brief The shares. */ - math::Vec shares; + math::Vec shares; /** * @brief The commitments. @@ -55,7 +55,7 @@ struct FeldmanSharing { * @return a Feldman secret-sharing. */ template -FeldmanSharing FeldmanShare(const typename G::Order& secret, +FeldmanSharing FeldmanShare(const typename G::ScalarField& secret, std::size_t t, std::size_t n, util::PRG& prg) { @@ -84,7 +84,7 @@ struct ShareAndIndex { /** * @brief The share. */ - typename G::Order share; + typename G::ScalarField share; }; /** @@ -100,7 +100,8 @@ struct ShareAndIndex { template bool FeldmanVerify(const ShareAndIndex& share_and_index, const math::Vec& commits) { - const auto ns = math::Vec::Range(1, commits.Size() + 1); + const auto ns = + math::Vec::Range(1, commits.Size() + 1); const auto lb = math::ComputeLagrangeBasis(ns, share_and_index.index); const auto v = math::UncheckedInnerProd(lb.begin(), lb.end(), commits.begin()); diff --git a/include/scl/ss/shamir.h b/include/scl/ss/shamir.h index 9c424d2..c69b214 100644 --- a/include/scl/ss/shamir.h +++ b/include/scl/ss/shamir.h @@ -36,11 +36,17 @@ namespace scl::ss { /** * @brief Create a Shamir secret-sharing. + * @tparam T a finite field type. * @param secret the secret to secret-share. * @param t the privacy threshold. * @param n the number of shares to output. * @param prg a prg for creating randomness. * @return a Shamir secret-sharing. + * + * This function creates a random polynomial \f$f\f$ of degree \f$t\f$ and such + * that \f$f(0)=\mathtt{secret}\f$. The return value is a list of evaluation + * points (the shares) defined as \f$(f(1), f(2),\dots,f(n))\f$, where the + * points in which \f$f\f$ is evaluated is called the alphas. */ template math::Vec ShamirShare(const T& secret, @@ -63,38 +69,67 @@ math::Vec ShamirShare(const T& secret, /** * @brief Recover a Shamir secret-shared secret. * @param shares the shares. + * @param alphas the alphas. + * @param x the evaluation point. * @return a value. * - * This function interpolates the polynomial \f$f\f$ passing through all of \p - * shares and then returns \f$f(0)\f$. + * This function interpolates a polynomial running through the points \f$(s_i, + * \alpha_i)\f$ where \f$s_i=\mathtt{share}[i]\f$ and + * \f$\alpha_i=\mathtt{alphas}[i]\f$ and returns \f$f(x)\f$. */ template -T ShamirRecoverP(const math::Vec& shares) { - const auto lb = - math::ComputeLagrangeBasis(math::Vec::Range(1, shares.Size() + 1), 0); +T ShamirRecoverP(const math::Vec& shares, + const math::Vec& alphas, + const T& x) { + const auto lb = math::ComputeLagrangeBasis(alphas, x); return math::UncheckedInnerProd(shares.begin(), shares.end(), lb.begin()); } +/** + * @brief Recover a Shamir secret-shared secret. + * @param shares the shares. + * @return a value. + * + * This function is identical to ss::ShamirRecoverP with + * \f$\mathtt{alphas}=(1,2,\dots,\mathtt{shares.size()} + 1)\f$ and + * \f$x=0\f$. It can be used to interpolate (with passive security) a share as + * obtained from ss::ShamirShare. + */ +template +T ShamirRecoverP(const math::Vec& shares) { + return ShamirRecoverP(shares, + math::Vec::Range(1, shares.Size() + 1), + T::Zero()); +} + /** * @brief Recover a Shamir secret-shared secret with error detection. * @param shares the shares. + * @param alphas the alphas. + * @param x the evaluation point. * @return a value. * @throws std::logic_error if the provided shares are not consistent. * - * This function attempts to interpolate a polynomial \f$f\f$ of degree - * \f$t=(\mathtt{shares.size()}-1)/2\f$ that passes through all the provided - * shares. If this succeeds, the \f$f(0)\f$ is returned. Otherwise an exception - * is thrown. + * Let \f$n=\mathtt{shares.size()}\f$ and \f$t=(n-1)/2\f$. This function + * interpolates a polynomial \f$f\f$ running through \f$(s_i,\alpha_i)\f$ where + * \f$s_i=\mathtt{shares}[i]\f$, \f$\alpha_i=\mathtt{alphas}[i]\f$ for + * \f$i=1,\dots,t\f$. Note that this implies that \f$f\f$ has degree + * \f$t\f$. The interpolated polynomial must be consistent with the remaining + * shares and alphas, that is \f$f(\alpha_i)=s_i\f$ for \f$i=t+1,\dots,n\f$. If + * this is the case, then \f$f(x)\f$ is returned, otherwise an + * std::logic_error is thrown. */ template -T ShamirRecoverD(const math::Vec& shares) { +T ShamirRecoverD(const math::Vec& shares, + const math::Vec& alphas, + const T& x) { const std::size_t t = (shares.Size() - 1) / 2; const std::size_t n = 2 * t + 1; - const auto ns = math::Vec::Range(1, t + 2); + const auto ns = alphas.SubVector(t + 1); for (std::size_t i = t + 1; i < n; ++i) { // Shares are indexed starting from 1. - auto lb = math::ComputeLagrangeBasis(ns, i + 1); + auto lb = math::ComputeLagrangeBasis(ns, alphas[i]); auto yi = math::UncheckedInnerProd(shares.begin(), shares.begin() + t + 1, lb.begin()); @@ -103,14 +138,42 @@ T ShamirRecoverD(const math::Vec& shares) { } } - auto lb = math::ComputeLagrangeBasis(ns, 0); + auto lb = math::ComputeLagrangeBasis(ns, x); return math::UncheckedInnerProd(shares.begin(), shares.begin() + t + 1, lb.begin()); } +/** + * @brief Recover a Shamir secret-shared secret with error detection. + * @param shares the shares. + * @return a value. + * + * This function is identical to ss::ShamirRecoverD with + * \f$\mathtt{alphas}=(1,\dots,\mathtt{shares.size()}+1)\f$ and \f$x=0\f$. + */ +template +T ShamirRecoverD(const math::Vec& shares) { + const std::size_t t = (shares.Size() - 1) / 2; + const std::size_t n = 2 * t + 1; + return ShamirRecoverD(shares, math::Vec::Range(1, n + 1), T::Zero()); +} + /** * @brief The result of an error corrected Shamir sharing. + * + *

When recovering a Shamir secret-shared value with error correction, the + * result is either two polynomials or an error, where an error only occurs when + * too many errors are present (i.e., when correction was not possible). + * + *

When correction is possible, the result is a pair \f$(f,e)\f$ where + * \f$f\f$ is the recovered polynomial, and in particular, \f$f(0)\f$ is the + * value that was secret-shared in case the sharing was constructed using + * ss::ShamirShare. The other polynomial \f$e\f$ indicates which shares were + * bad. I.e., \f$e(\alpha_i)=0\f$ says that the evaluation point + * \f$(s_i,\alpha_i)\f$ did not lie on the polynomial \f$f\f$. Usually, + * \f$\alpha_i\f$ is a party identifier, so this is the same as saying that + * party \f$P_{\alpha_i}\f$ sent an invalid share. */ template struct ErrorCorrectedSecret { @@ -128,14 +191,24 @@ struct ErrorCorrectedSecret { /** * @brief Recover a Shamir secret-shared secret with error correction. * @param shares the shares. + * @param alphas the alphas. * @return a pair of polynomials. * @throws std::logic_error if error correction failed. + * + *

Let \f$n=\mathtt{shares.size()}\f$ and \f$t=(n-1)/3\f$. Given a list of + * evaluation points \f$(s_i,\alpha_i)\f$ with \f$s_i=\mathtt{shares}[i]\f$ and + * \f$\alpha_i=\mathtt{alphas}[i]\f$, this function attempts to recover a + * polynomial \f$f\f$ of degree \f$t\f$. If this is possible, the recovered + * polynomial is returned together with a polynomial indicating which supplied + * shares did not lie on the polynomial. + * + *

This function can correct up to \f$t\f$ errors in the supplied shares. */ template -ErrorCorrectedSecret ShamirRecoverC(const math::Vec& shares) { +ErrorCorrectedSecret ShamirRecoverC(const math::Vec& shares, + const math::Vec& alphas) { const std::size_t t = (shares.Size() - 1) / 3; const std::size_t n = 3 * t + 1; - const auto ns = math::Vec::Range(1, shares.Size() + 1); math::Mat A(n); math::Vec b(n); @@ -148,13 +221,13 @@ ErrorCorrectedSecret ShamirRecoverC(const math::Vec& shares) { b[i] = -shares[i]; A(i, 0) = shares[i]; for (int j = 1; j <= e; ++j) { - A(i, j) = A(i, j - 1) * ns[i]; - b[i] *= ns[i]; + A(i, j) = A(i, j - 1) * alphas[i]; + b[i] *= alphas[i]; } A(i, e) = -T(1); for (std::size_t j = e + 1; j < n; ++j) { - A(i, j) = A(i, j - 1) * ns[i]; + A(i, j) = A(i, j - 1) * alphas[i]; } } @@ -177,6 +250,19 @@ ErrorCorrectedSecret ShamirRecoverC(const math::Vec& shares) { return {qr[0], E}; } +/** + * @brief Recover a Shamir secret-shared secret with error correction. + * @param shares the shares. + * @return a pair of polynomials. + * + * This function is identical to ss::ShamirRecoverC with + * \f$\mathtt{alphas}=(1,\dots,\mathtt{shares.size()}+1)\f$. + */ +template +ErrorCorrectedSecret ShamirRecoverC(const math::Vec& shares) { + return ShamirRecoverC(shares, math::Vec::Range(1, shares.Size() + 1)); +} + } // namespace scl::ss #endif // SCL_SS_SHAMIR_H diff --git a/include/scl/util/digest.h b/include/scl/util/digest.h index 6a69958..0c06356 100644 --- a/include/scl/util/digest.h +++ b/include/scl/util/digest.h @@ -20,7 +20,6 @@ #include #include -#include #include #include "scl/util/str.h" @@ -34,12 +33,7 @@ namespace scl::util { * This type is effectively std::array. */ template -struct Digest { - /** - * @brief The actual type of a digest. - */ - using Type = std::array; -}; +using Digest = std::array; /** * @brief Convert a digest to a string. diff --git a/include/scl/util/hash.h b/include/scl/util/hash.h index 794c1e1..1c147f7 100644 --- a/include/scl/util/hash.h +++ b/include/scl/util/hash.h @@ -20,6 +20,7 @@ #include +#include "scl/util/digest.h" #include "scl/util/sha3.h" namespace scl::util { diff --git a/include/scl/util/iuf_hash.h b/include/scl/util/iuf_hash.h index 43b8120..e1fbef5 100644 --- a/include/scl/util/iuf_hash.h +++ b/include/scl/util/iuf_hash.h @@ -21,9 +21,12 @@ #include #include #include +#include #include #include +#include "scl/serialization/serializers.h" + namespace scl::util { /** @@ -74,6 +77,20 @@ struct IUFHash { string.size()); } + /** + * @brief Update the hash function with the content of a serializable type. + * @param data the data. + * @return the updated Hash object. + */ + template + IUFHash& Update(const T& data) { + using Sr = seri::Serializer; + const auto size = Sr::SizeOf(data); + const auto buf = std::make_unique(size); + Sr::Write(data, buf.get()); + return Update(buf.get(), size); + } + /** * @brief Finalize and return the digest. * @return a digest. diff --git a/include/scl/util/merkle.h b/include/scl/util/merkle.h new file mode 100644 index 0000000..01afffd --- /dev/null +++ b/include/scl/util/merkle.h @@ -0,0 +1,194 @@ +/* SCL --- Secure Computation Library + * Copyright (C) 2023 Anders Dalskov + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +#ifndef SCL_UTIL_MERKLE_H +#define SCL_UTIL_MERKLE_H + +#include + +#include "scl/util/digest.h" + +namespace scl::util { + +/** + * @brief Merkle hash tree. + * @tparam H a hash function. + * @tparam T the leaf data type. + */ +template +struct MerkleTree { + /** + * @brief The digest type nodes. + */ + using DigestType = typename H::DigestType; + + /** + * @brief Compute a Merkle tree hash. + * @param data the date to hash. + * @return the root hash. + */ + static DigestType Hash(const std::vector& data); + + /** + * @brief A Merkle tree proof. + */ + struct Proof { + /** + * @brief The path from a particular leaf to the root. + */ + std::vector path; + + /** + * @brief A vector describing whether at the left or right element for each + * element in a path. + */ + std::vector direction; + }; + + /** + * @brief Create a proof that a particular index is part of a Merkle tree. + */ + static Proof Prove(const std::vector& data, std::size_t index); + + /** + * @brief Verify a Merkle tree proof. + * @param value the statement. + * @param root the tree root. + * @param proof the proof + * @return true if the + */ + static bool Verify(const T& value, + const DigestType& root, + const Proof& proof); + + private: + static std::vector HashLeafs(const std::vector& data); +}; + +template +auto MerkleTree::HashLeafs(const std::vector& data) + -> std::vector { + std::vector digests; + auto sz = data.size(); + digests.reserve(sz); + + for (const auto& d : data) { + H hash; + digests.emplace_back(hash.Update(d).Finalize()); + } + + // duplicate the last hash in case there's an odd number of leafs. + if (data.size() % 2 == 1) { + digests.emplace_back(digests.back()); + sz++; + } + + return digests; +} // LCOV_EXCL_LINE + +template +auto MerkleTree::Hash(const std::vector& data) -> DigestType { + std::vector digests = HashLeafs(data); + + auto sz = digests.size(); + + while (sz > 1) { + std::size_t j = 0; + for (std::size_t i = 0; i < sz; i += 2) { + const auto left = digests[i]; + const auto right = digests[i + 1]; + H hash; + digests[j] = hash.Update(left).Update(right).Finalize(); + j++; + } + + sz /= 2; + + // Duplicate the last node if there's an odd number of leafs. + if (sz > 1 && sz % 2 == 1) { + digests[j] = digests[j - 1]; + sz++; + } + } + + return digests[0]; +} + +template +auto MerkleTree::Prove(const std::vector& data, std::size_t index) + -> Proof { + std::vector digests = HashLeafs(data); + std::vector path; + std::vector direction; + + auto sz = digests.size(); + + while (sz > 1) { + std::size_t j = 0; + for (std::size_t i = 0; i < sz; i += 2) { + const auto left = digests[i]; + const auto right = digests[i + 1]; + + H hash; + digests[j] = hash.Update(left).Update(right).Finalize(); + + if (i == index) { + path.emplace_back(right); + direction.emplace_back(false); + index = j; + } else if (i + 1 == index) { + path.emplace_back(left); + direction.emplace_back(true); + index = j; + } + + j++; + } + + sz /= 2; + + if (sz > 1 && sz % 2 == 1) { + digests[j] = digests[j - 1]; + sz++; + } + } + + return {path, direction}; +} + +template +bool MerkleTree::Verify(const T& value, + const DigestType& root, + const Proof& proof) { + const auto [h, d] = proof; + + auto digest = H{}.Update(value).Finalize(); + for (std::size_t i = 0; i < h.size(); ++i) { + H hash; + if (d[i]) { + digest = hash.Update(h[i]).Update(digest).Finalize(); + } else { + digest = hash.Update(digest).Update(h[i]).Finalize(); + } + } + + return root == digest; +} + +} // namespace scl::util + +#endif // SCL_UTIL_MERKLE_H diff --git a/include/scl/util/sha256.h b/include/scl/util/sha256.h index 88edcbb..da29caf 100644 --- a/include/scl/util/sha256.h +++ b/include/scl/util/sha256.h @@ -36,7 +36,7 @@ class Sha256 final : public IUFHash { /** * @brief The type of a SHA256 digest. */ - using DigestType = typename Digest<256>::Type; + using DigestType = Digest<256>; /** * @brief Update the hash function with a set of bytes. diff --git a/include/scl/util/sha3.h b/include/scl/util/sha3.h index 26289c9..7919cfd 100644 --- a/include/scl/util/sha3.h +++ b/include/scl/util/sha3.h @@ -41,7 +41,7 @@ class Sha3 final : public IUFHash> { /** * @brief The type of a SHA3 digest. */ - using DigestType = typename Digest::Type; + using DigestType = Digest; /** * @brief Update the hash function with a set of bytes. diff --git a/include/scl/util/sign.h b/include/scl/util/sign.h index 5565e2b..05751d1 100644 --- a/include/scl/util/sign.h +++ b/include/scl/util/sign.h @@ -40,7 +40,7 @@ class ECDSA; template <> struct Signature { private: - using ElementType = math::FF; + using ElementType = math::FF; public: /** @@ -84,18 +84,16 @@ struct Signature { * @brief The ECDSA signature scheme. */ class ECDSA { - using Curve = math::EC; - public: /** * @brief Public key type. A curve point. */ - using PublicKey = Curve; + using PublicKey = math::EC; /** * @brief Secret key type. An element modulo the order of the curve. */ - using SecretKey = Curve::Order; + using SecretKey = PublicKey::ScalarField; /** * @brief Derive the public key correspond to a given secret key. @@ -118,10 +116,10 @@ class ECDSA { static Signature Sign(const SecretKey& secret_key, const D& digest, PRG& prg) { - const auto k = Curve::Order::Random(prg); + const auto k = SecretKey::Random(prg); const auto R = k * PublicKey::Generator(); const auto rx = ConversionFunc(R); - const auto h = DigestToElement(digest); + const auto h = DigestToElement(digest); return {rx, k.Inverse() * (h + secret_key * rx)}; } @@ -137,31 +135,44 @@ class ECDSA { static bool Verify(const PublicKey& public_key, const Signature& signature, const D& digest) { - const auto h = DigestToElement(digest); + const auto h = DigestToElement(digest); const auto [r, s] = signature; const auto si = s.Inverse(); - const auto R1 = (h * si) * Curve::Generator(); + const auto R1 = (h * si) * PublicKey::Generator(); const auto R2 = (r * si) * public_key; const auto R = R1 + R2; return !R.PointAtInfinity() && ConversionFunc(R) == r; } - private: - template - static T DigestToElement(const D& digest) { - if (digest.size() < T::ByteSize()) { - unsigned char buf[T::ByteSize()] = {0}; - std::copy(digest.begin(), digest.end(), buf); - return T::Read(buf); - } - return T::Read(digest.data()); - } - - static Curve::Order ConversionFunc(const PublicKey& R) { + /** + * @brief Computes the ECDSA conversion function. + * @param R the curve point to convert into a scalar field element. + * @return a scalar. + * + * This function computes the \f$C(R)\f$ function that takes curve point + * \f$R=(r_x, r_y)\f$ and outputs a scalar as \f$r_x \mod p\f$ where \f$p\f$ + * is order of a subgroup. + */ + static SecretKey ConversionFunc(const PublicKey& R) { const auto rx_f = R.ToAffine()[0]; - unsigned char rx_bytes[Curve::Field::ByteSize()]; + unsigned char rx_bytes[SecretKey::ByteSize()]; rx_f.Write(rx_bytes); - return Curve::Order::Read(rx_bytes); + return SecretKey::Read(rx_bytes); + } + + /** + * @brief Converts a digest into an element of the scalar field. + * @param digest the digest. + * @return a scalar. + */ + template + static SecretKey DigestToElement(const D& digest) { + if (digest.size() < SecretKey::ByteSize()) { + unsigned char buf[SecretKey::ByteSize()] = {0}; + std::copy(digest.begin(), digest.end(), buf); + return SecretKey::Read(buf); + } + return SecretKey::Read(digest.data()); } }; diff --git a/include/scl/util/traits.h b/include/scl/util/traits.h index 65121a3..5805195 100644 --- a/include/scl/util/traits.h +++ b/include/scl/util/traits.h @@ -19,15 +19,43 @@ #define SCL_UTIL_TRAITS_H #include +#include #include namespace scl::util { +/// @cond + template -struct IsStdVector : std::false_type {}; +struct IsStdVectorImpl : std::false_type {}; template -struct IsStdVector> : std::true_type {}; +struct IsStdVectorImpl> : std::true_type {}; + +// https://stackoverflow.com/a/35207812 +template +struct HasOperatorMulImpl { + template + static auto Test(TT*) -> decltype(std::declval() * std::declval()); + + template + static auto Test(...) -> std::false_type; + + using Type = typename std::is_same(0))>::type; +}; + +/// @endcond + +/** + * @brief Trait for determining if two types can be multipled. + * @tparam T the first type. + * @tparam V the second type. + * + * This trait evalutes to an std::true_type if T operator*(V) is + * defined. + */ +template +struct HasOperatorMul : HasOperatorMulImpl::Type {}; } // namespace scl::util diff --git a/src/scl/math/secp256k1_curve.cc b/src/scl/math/secp256k1_curve.cc index 583ed42..3c310d7 100644 --- a/src/scl/math/secp256k1_curve.cc +++ b/src/scl/math/secp256k1_curve.cc @@ -25,8 +25,10 @@ #include "scl/math/ec_ops.h" #include "scl/math/fp.h" -using Curve = scl::math::Secp256k1; -using Field = scl::math::FF; +using namespace scl; + +using Curve = math::Secp256k1; +using Field = math::FF; using Point = Curve::ValueType; // clang-format off @@ -40,7 +42,7 @@ using Point = Curve::ValueType; static const Field kCurveB(7); template <> -void scl::math::CurveSetPointAtInfinity(Point& out) { +void math::CurveSetPointAtInfinity(Point& out) { out = POINT_AT_INFINITY; } @@ -55,38 +57,36 @@ bool Valid(const Field& x, const Field& y) { } // namespace template <> -void scl::math::CurveSetAffine(Point& out, - const Field& x, - const Field& y) { +void math::CurveSetAffine(Point& out, const Field& x, const Field& y) { if (Valid(x, y)) { - out = {x, y, Field(1)}; + out = {x, y, Field::One()}; } else { throw std::invalid_argument("provided (x, y) not on curve"); } } template <> -std::array scl::math::CurveToAffine(const Point& point) { - const auto Z = GET_Z(point); - return {GET_X(point) / Z, GET_Y(point) / Z}; +std::array math::CurveToAffine(const Point& point) { + const auto Z = GET_Z(point).Inverse(); + return {GET_X(point) * Z, GET_Y(point) * Z}; } template <> -bool scl::math::CurveEqual(const Point& in1, const Point& in2) { - const auto Z1 = GET_Z(in1); - const auto Z2 = GET_Z(in2); +bool math::CurveEqual(const Point& in1, const Point& in2) { + const auto& Z1 = GET_Z(in1); + const auto& Z2 = GET_Z(in2); // (X1, Y1, Z1) eqv (X2, Y2, Z2) <==> (X1 * Z2, Y1 * Z2) == (X2 * Z1, Y2 * Z2) return GET_X(in1) * Z2 == GET_X(in2) * Z1 && GET_Y(in1) * Z2 == GET_Y(in2) * Z1; } template <> -bool scl::math::CurveIsPointAtInfinity(const Point& point) { - return CurveEqual(point, POINT_AT_INFINITY); +bool math::CurveIsPointAtInfinity(const Point& point) { + return GET_Z(point) == Field::Zero(); } template <> -std::string scl::math::CurveToString(const Point& point) { +std::string math::CurveToString(const Point& point) { std::string str; if (CurveIsPointAtInfinity(point)) { str = "EC{POINT_AT_INFINITY}"; @@ -100,7 +100,7 @@ std::string scl::math::CurveToString(const Point& point) { } // LCOV_EXCL_LINE template <> -void scl::math::CurveSetGenerator(Point& out) { +void math::CurveSetGenerator(Point& out) { static const Point gen = { Field::FromString( "79be667ef9dcbbac55a06295ce870b07029bfcdb2dce28d959f2815b16f81798"), @@ -112,69 +112,97 @@ void scl::math::CurveSetGenerator(Point& out) { } template <> -void scl::math::CurveDouble(Point& out) { - if (!CurveIsPointAtInfinity(out)) { - if (GET_Y(out) == Field::Zero()) { - CurveSetPointAtInfinity(out); - } else if (!CurveIsPointAtInfinity(out)) { - const auto X = GET_X(out); - const auto Y = GET_Y(out); - const auto Z = GET_Z(out); - - const auto W = Field(3) * X * X; - const auto S = Y * Z; - const auto B = X * Y * S; - const auto eight = Field(8); - const auto H = W * W - eight * B; - - out[0] = Field(2) * H * S; - const auto Ssqr = S * S; - out[1] = W * (Field(4) * B - H) - eight * Y * Y * Ssqr; - out[2] = eight * Ssqr * S; - } - } +void math::CurveDouble(Point& out) { + // https://eprint.iacr.org/2015/1060.pdf algorithm 9. + + static const auto b3 = Field(3 * 7); + + auto t0 = GET_Y(out) * GET_Y(out); + auto z3 = t0 + t0; + z3 = z3 + z3; + + z3 = z3 + z3; + auto t1 = GET_Y(out) * GET_Z(out); + auto t2 = GET_Z(out) * GET_Z(out); + + t2 = b3 * t2; + auto x3 = t2 * z3; + auto y3 = t0 + t2; + + z3 = t1 * z3; + t1 = t2 + t2; + t2 = t1 + t2; + + t0 = t0 - t2; + y3 = t0 * y3; + y3 = x3 + y3; + + t1 = GET_X(out) * GET_Y(out); + x3 = t0 * t1; + x3 = x3 + x3; + + out[0] = x3; + out[1] = y3; + out[2] = z3; } template <> -void scl::math::CurveAdd(Point& out, const Point& in) { - if (CurveIsPointAtInfinity(out)) { - out = in; - } else if (!CurveIsPointAtInfinity(in)) { - const auto X1 = GET_X(out); - const auto Y1 = GET_Y(out); - const auto Z1 = GET_Z(out); - const auto X2 = GET_X(in); - const auto Y2 = GET_Y(in); - const auto Z2 = GET_Z(in); - - const auto U1 = Y2 * Z1; - const auto U2 = Y1 * Z2; - const auto V1 = X2 * Z1; - const auto V2 = X1 * Z2; - - if (V1 == V2) { - if (U1 != U2) { - CurveSetPointAtInfinity(out); - } else { - CurveDouble(out); - } - } else { - const auto U = U1 - U2; - const auto V = V1 - V2; - const auto W = Z1 * Z2; - const auto Vsqr = V * V; - const auto VsqrV2 = Vsqr * V2; - const auto Vcbe = Vsqr * V; - const auto A = U * U * W - Vcbe - Field(2) * VsqrV2; - out[0] = V * A; - out[1] = U * (VsqrV2 - A) - Vcbe * U2; - out[2] = Vcbe * W; - } - } +void math::CurveAdd(Point& out, const Point& in) { + // https://eprint.iacr.org/2015/1060.pdf algorithm 7 + + static const auto b3 = Field(3 * 7); + + auto t0 = GET_X(out) * GET_X(in); + auto t1 = GET_Y(out) * GET_Y(in); + auto t2 = GET_Z(out) * GET_Z(in); + + auto t3 = GET_X(out) + GET_Y(out); + auto t4 = GET_X(in) + GET_Y(in); + t3 = t3 * t4; + + t4 = t0 + t1; + t3 = t3 - t4; + t4 = GET_Y(out) + GET_Z(out); + + auto x3 = GET_Y(in) + GET_Z(in); + t4 = t4 * x3; + x3 = t1 + t2; + + t4 = t4 - x3; + x3 = GET_X(out) + GET_Z(out); + auto y3 = GET_X(in) + GET_Z(in); + + x3 = x3 * y3; + y3 = t0 + t2; + y3 = x3 - y3; + + x3 = t0 + t0; + t0 = x3 + t0; + t2 = b3 * t2; + + auto z3 = t1 + t2; + t1 = t1 - t2; + y3 = b3 * y3; + + x3 = t4 * y3; + t2 = t3 * t1; + x3 = t2 - x3; + + y3 = y3 * t0; + t1 = t1 * z3; + y3 = t1 + y3; + + t0 = t0 * t3; + z3 = z3 * t4; + z3 = z3 + t0; + + out[0] = x3; + out[1] = y3; + out[2] = z3; } template <> -void scl::math::CurveNegate(Point& out) { +void math::CurveNegate(Point& out) { if (GET_Y(out) == Field::Zero()) { CurveSetPointAtInfinity(out); } else { @@ -183,14 +211,14 @@ void scl::math::CurveNegate(Point& out) { } template <> -void scl::math::CurveSubtract(Point& out, const Point& in) { +void math::CurveSubtract(Point& out, const Point& in) { Point copy(in); CurveNegate(copy); CurveAdd(out, copy); } template <> -void scl::math::CurveScalarMultiply(Point& out, const Number& scalar) { +void math::CurveScalarMultiply(Point& out, const Number& scalar) { if (!CurveIsPointAtInfinity(out)) { const auto n = scalar.BitSize(); Point res; @@ -207,16 +235,16 @@ void scl::math::CurveScalarMultiply(Point& out, const Number& scalar) { } template <> -void scl::math::CurveScalarMultiply(Point& out, - const FF& scalar) { +void math::CurveScalarMultiply(Point& out, + const FF& scalar) { if (!CurveIsPointAtInfinity(out)) { - auto x = FFAccess::FromMonty(scalar); - const auto n = FFAccess::HigestSetBit(x); + auto x = FFAccess::FromMonty(scalar); + const auto n = FFAccess::HigestSetBit(x); Point res; CurveSetPointAtInfinity(res); for (auto i = n; i-- > 0;) { CurveDouble(res); - if (FFAccess::TestBit(x, i)) { + if (FFAccess::TestBit(x, i)) { CurveAdd(res, out); } } @@ -241,18 +269,18 @@ namespace { Field ComputeOtherCoordinate(const Field& x) { auto y_sqr = x * x * x + kCurveB; - auto z = scl::math::FFAccess::ComputeSqrt(y_sqr); + auto z = math::FFAccess::ComputeSqrt(y_sqr); return z; } bool IsSmaller(const Field& y, const Field& y_neg) { - return scl::math::FFAccess::IsSmaller(y, y_neg); + return math::FFAccess::IsSmaller(y, y_neg); } } // namespace template <> -void scl::math::CurveFromBytes(Point& out, const unsigned char* src) { +void math::CurveFromBytes(Point& out, const unsigned char* src) { const auto flags = *src; if (IS_POINT_AT_INFINITY(flags)) { @@ -290,9 +318,9 @@ void scl::math::CurveFromBytes(Point& out, const unsigned char* src) { #define MARK_SELECT_SMALLER(buf) (*(buf) |= SELECT_SMALLER_FLAG) template <> -void scl::math::CurveToBytes(unsigned char* dest, - const Point& in, - bool compress) { +void math::CurveToBytes(unsigned char* dest, + const Point& in, + bool compress) { // Make sure flag byte is zeroed. *dest = 0; @@ -313,7 +341,7 @@ void scl::math::CurveToBytes(unsigned char* dest, // x and y. if (compress) { // include a flag which indicates which of {y, -y} is the smaller. - const auto y = ap[1]; + const auto& y = ap[1]; const auto yn = y.Negated(); if (IsSmaller(y, yn)) { diff --git a/src/scl/math/secp256k1_field.cc b/src/scl/math/secp256k1_field.cc index 43aa0a0..877bde8 100644 --- a/src/scl/math/secp256k1_field.cc +++ b/src/scl/math/secp256k1_field.cc @@ -36,21 +36,27 @@ using Elem = Field::ValueType; } \ } while (0) -// The prime modulus p -static const mp_limb_t kPrime[] = { - 0xFFFFFFFEFFFFFC2F, // - 0xFFFFFFFFFFFFFFFF, // - 0xFFFFFFFFFFFFFFFF, // - 0xFFFFFFFFFFFFFFFF // -}; - -// n' such that 2^{256} * a + kPrime * n' == 1 -static const mp_limb_t kMontyN[] = { - 0x27C7F6E22DDACACF, // - 0x434DDC0123DB5FA6, // - 0x63B93D3D6A0D489E, // - 0x3642E6FAEAAC7C66 // -}; +static const scl::math::RedParams RD = { + // Prime + { + 0xFFFFFFFEFFFFFC2F, // + 0xFFFFFFFFFFFFFFFF, // + 0xFFFFFFFFFFFFFFFF, // + 0xFFFFFFFFFFFFFFFF // + }, + // Montgomery constant + { + 0xD838091DD2253531, // + 0xBCB223FEDC24A059, // + 0x9C46C2C295F2B761, // + 0xC9BD190515538399 // + }}; + +template <> +scl::math::Number scl::math::Order>() { + return Number::FromString( + "FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEFFFFFC2F"); +} // The internal data type is an STL array, but gmp expects pointers. #define PTR(X) (X).data() @@ -58,27 +64,27 @@ static const mp_limb_t kMontyN[] = { template <> void scl::math::FieldConvertIn(Elem& out, const int value) { out = {0}; - MontyInFromInt(PTR(out), value, kPrime); + MontyInFromInt(PTR(out), value, RD); } template <> void scl::math::FieldAdd(Elem& out, const Elem& op) { - MontyModAdd(PTR(out), PTR(op), kPrime); + MontyModAdd(PTR(out), PTR(op), RD); } template <> void scl::math::FieldSubtract(Elem& out, const Elem& op) { - MontyModSub(PTR(out), PTR(op), kPrime); + MontyModSub(PTR(out), PTR(op), RD); } template <> void scl::math::FieldNegate(Elem& out) { - MontyModNeg(PTR(out), kPrime); + MontyModNeg(PTR(out), RD); } template <> void scl::math::FieldMultiply(Elem& out, const Elem& op) { - MontyModMul(PTR(out), PTR(op), kPrime, kMontyN); + MontyModMul(PTR(out), PTR(op), RD); } #define ONE \ @@ -94,7 +100,7 @@ void scl::math::FieldInvert(Elem& out) { }; Elem res = ONE; - MontyModInv(PTR(res), PTR(out), kPrime, kPrimeMinus2, kMontyN); + MontyModInv(PTR(res), PTR(out), kPrimeMinus2, RD); out = res; } @@ -105,23 +111,23 @@ bool scl::math::FieldEqual(const Elem& in1, const Elem& in2) { template <> void scl::math::FieldFromBytes(Elem& dest, const unsigned char* src) { - MontyFromBytes(PTR(dest), src, kPrime); + MontyFromBytes(PTR(dest), src, RD); } template <> void scl::math::FieldToBytes(unsigned char* dest, const Elem& src) { - MontyToBytes(dest, PTR(src), kPrime, kMontyN); + MontyToBytes(dest, PTR(src), RD); } template <> void scl::math::FieldFromString(Elem& out, const std::string& src) { out = {0}; - MontyFromString(PTR(out), kPrime, src); + MontyFromString(PTR(out), src, RD); } template <> std::string scl::math::FieldToString(const Elem& in) { - return MontyToString(PTR(in), kPrime, kMontyN); + return MontyToString(PTR(in), RD); } bool scl::math::FFAccess::IsSmaller( @@ -144,7 +150,7 @@ scl::math::FF scl::math::FFAccess::ComputeSqrt( FF out; Elem res = ONE; - MontyModExp(PTR(res), PTR(x.m_value), e, kPrime, kMontyN); + MontyModExp(PTR(res), PTR(x.m_value), e, RD); out.m_value = res; return out; } // LCOV_EXCL_LINE diff --git a/src/scl/math/secp256k1_helpers.h b/src/scl/math/secp256k1_helpers.h index 583829d..3053e04 100644 --- a/src/scl/math/secp256k1_helpers.h +++ b/src/scl/math/secp256k1_helpers.h @@ -44,23 +44,23 @@ struct FFAccess { * @brief Helper class for Secp256k1::Order. */ template <> -struct FFAccess { +struct FFAccess { /** * @brief Convert a field element out of montgomery representation. */ - static FF FromMonty(const FF& element); + static FF FromMonty(const FF& element); /** * @brief Find the position of the highest set bit. */ - static std::size_t HigestSetBit(const FF& element); + static std::size_t HigestSetBit(const FF& element); /** * @brief Check if a particular bit is set. * * \p pos is assumed to be at or below HighestSetBit(\p element). */ - static bool TestBit(const FF& element, std::size_t pos); + static bool TestBit(const FF& element, std::size_t pos); }; } // namespace scl::math diff --git a/src/scl/math/secp256k1_order.cc b/src/scl/math/secp256k1_scalar.cc similarity index 75% rename from src/scl/math/secp256k1_order.cc rename to src/scl/math/secp256k1_scalar.cc index 7ddcbae..91725fd 100644 --- a/src/scl/math/secp256k1_order.cc +++ b/src/scl/math/secp256k1_scalar.cc @@ -27,7 +27,7 @@ #include "scl/math/ff_ops.h" #include "scl/math/ops_gmp_ff.h" -using Field = scl::math::Secp256k1::Order; +using Field = scl::math::Secp256k1::Scalar; using Elem = Field::ValueType; #define NUM_LIMBS 4 @@ -39,46 +39,54 @@ using Elem = Field::ValueType; } \ } while (0) -static const mp_limb_t kPrime[] = { - 0xBFD25E8CD0364141, // - 0xBAAEDCE6AF48A03B, // - 0xFFFFFFFFFFFFFFFE, // - 0xFFFFFFFFFFFFFFFF // -}; +template <> +scl::math::Number scl::math::Order>() { + return Number::FromString( + "FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEBAAEDCE6AF48A03BBFD25E8CD0364141"); +} -static const mp_limb_t kMontyN[] = { - 0xB4F20099AA774EC1, // - 0xAF5AE537CB4613DB, // - 0x7680CF3ED83054A1, // - 0x261776F29B6B106C // -}; +static const scl::math::RedParams RD = { + // Prime + { + 0xBFD25E8CD0364141, // + 0xBAAEDCE6AF48A03B, // + 0xFFFFFFFFFFFFFFFE, // + 0xFFFFFFFFFFFFFFFF // + }, + // Montgomery constant + { + 0x4B0DFF665588B13F, // + 0x50A51AC834B9EC24, // + 0x897F30C127CFAB5E, // + 0xD9E8890D6494EF93 // + }}; #define PTR(X) (X).data() template <> void scl::math::FieldConvertIn(Elem& out, const int value) { out = {0}; - MontyInFromInt(PTR(out), value, kPrime); + MontyInFromInt(PTR(out), value, RD); } template <> void scl::math::FieldAdd(Elem& out, const Elem& op) { - MontyModAdd(PTR(out), PTR(op), kPrime); + MontyModAdd(PTR(out), PTR(op), RD); } template <> void scl::math::FieldSubtract(Elem& out, const Elem& op) { - MontyModSub(PTR(out), PTR(op), kPrime); + MontyModSub(PTR(out), PTR(op), RD); } template <> void scl::math::FieldNegate(Elem& out) { - MontyModNeg(PTR(out), kPrime); + MontyModNeg(PTR(out), RD); } template <> void scl::math::FieldMultiply(Elem& out, const Elem& op) { - MontyModMul(PTR(out), PTR(op), kPrime, kMontyN); + MontyModMul(PTR(out), PTR(op), RD); } #define ONE \ @@ -94,7 +102,7 @@ void scl::math::FieldInvert(Elem& out) { }; Elem res = ONE; - MontyModInv(PTR(res), PTR(out), kPrime, kPrimeMinus2, kMontyN); + MontyModInv(PTR(res), PTR(out), kPrimeMinus2, RD); out = res; } @@ -105,23 +113,23 @@ bool scl::math::FieldEqual(const Elem& in1, const Elem& in2) { template <> void scl::math::FieldFromBytes(Elem& dest, const unsigned char* src) { - MontyFromBytes(PTR(dest), src, kPrime); + MontyFromBytes(PTR(dest), src, RD); } template <> void scl::math::FieldToBytes(unsigned char* dest, const Elem& src) { - MontyToBytes(dest, PTR(src), kPrime, kMontyN); + MontyToBytes(dest, PTR(src), RD); } template <> std::string scl::math::FieldToString(const Elem& in) { - return MontyToString(PTR(in), kPrime, kMontyN); + return MontyToString(PTR(in), RD); } template <> void scl::math::FieldFromString(Elem& out, const std::string& src) { out = {0}; - MontyFromString(PTR(out), kPrime, src); + MontyFromString(PTR(out), src, RD); } std::size_t scl::math::FFAccess::HigestSetBit( @@ -141,7 +149,7 @@ scl::math::FF scl::math::FFAccess::FromMonty( const scl::math::FF& element) { mp_limb_t padded[2 * NUM_LIMBS] = {0}; SCL_COPY(padded, PTR(element.m_value), NUM_LIMBS); - MontyRedc(padded, kPrime, kMontyN); + MontyRedc(padded, RD); FF r; SCL_COPY(PTR(r.m_value), padded, NUM_LIMBS); diff --git a/src/scl/simulation/channel.cc b/src/scl/simulation/channel.cc index fd196c1..118eea1 100644 --- a/src/scl/simulation/channel.cc +++ b/src/scl/simulation/channel.cc @@ -27,8 +27,7 @@ using EventPtr = std::shared_ptr; -EventPtr scl::sim::SimulateClose(std::shared_ptr ctx, - ChannelId id) { +EventPtr scl::sim::SimulateClose(std::shared_ptr ctx, ChannelId id) { const auto lid = id.local; const auto trt = ctx->Checkpoint(lid); return std::make_shared(Event::Type::CLOSE, trt, id); @@ -37,13 +36,13 @@ EventPtr scl::sim::SimulateClose(std::shared_ptr ctx, #define SCL_LOCAL_COMP_BEGIN const auto scl__lcb = scl::util::Time::Now() #define SCL_LOCAL_COMP_END scl::util::Time::Now() - scl__lcb -EventPtr scl::sim::SimulateSend(std::shared_ptr ctx, +EventPtr scl::sim::SimulateSend(std::shared_ptr ctx, ChannelId id, const unsigned char* src, std::size_t n) { SCL_LOCAL_COMP_BEGIN; - ctx->Buffer(id)->Write({src, src + n}); + ctx->Buffer(id)->Write(src, n); const auto local_comp_time = SCL_LOCAL_COMP_END; const auto exec_time = ctx->Checkpoint(id.local) - local_comp_time; @@ -51,46 +50,36 @@ EventPtr scl::sim::SimulateSend(std::shared_ptr ctx, auto event = std::make_shared(Event::Type::SEND, exec_time, id, n); ctx->AddCandidateToRun(id.remote); - ctx->RecordWrite(id, n, exec_time); + ctx->AddWrite(id, n, exec_time); return event; } namespace { -scl::util::Time::Duration AdjustRecvTime( - std::shared_ptr ctx, - scl::sim::ChannelId id, - scl::util::Time::Duration t, - std::size_t n) { +scl::util::Time::Duration AdjustRecvTime(std::shared_ptr ctx, + scl::sim::ChannelId id, + scl::util::Time::Duration t, + std::size_t n) { auto rem = n; - auto wb = ctx->Writes(id).begin(); - auto we = ctx->Writes(id).end(); - - while (rem > 0 && wb != we) { - // TODO: It would probably be nicer to let ctx clean up the list of write - // ops so this check isn't necessary. - - if (wb->amount == 0) { - wb++; - continue; - } + while (rem > 0 && ctx->HasWrite(id)) { + auto& w = ctx->NextWrite(id); scl::util::Time::Duration recv_time; - if (wb->amount >= rem) { - const auto delay = scl::sim::ComputeRecvTime(ctx->NetworkConfig(id), rem); - recv_time = wb->time + delay; - wb->amount -= rem; + if (w.amount > rem) { + const auto delay = + scl::sim::ComputeRecvTime(ctx->ChannelConfiguration(id), rem); + recv_time = w.time + delay; + w.amount -= rem; rem = 0; - } else /* wb->amount < rem */ { + } else { const auto delay = - scl::sim::ComputeRecvTime(ctx->NetworkConfig(id), wb->amount); - recv_time = wb->time + delay; - rem -= wb->amount; - wb->amount = 0; + scl::sim::ComputeRecvTime(ctx->ChannelConfiguration(id), w.amount); + recv_time = w.time + delay; + rem -= w.amount; + ctx->DeleteWrite(id); } t = std::max(t, recv_time); - wb++; } return t; @@ -98,7 +87,7 @@ scl::util::Time::Duration AdjustRecvTime( } // namespace -EventPtr scl::sim::SimulateRecv(std::shared_ptr ctx, +EventPtr scl::sim::SimulateRecv(std::shared_ptr ctx, ChannelId id, unsigned char* dst, std::size_t n) { @@ -109,8 +98,7 @@ EventPtr scl::sim::SimulateRecv(std::shared_ptr ctx, throw SimulationFailure(); } - auto data = ctx->Buffer(id)->Read(n); - std::copy(data.begin(), data.end(), dst); + ctx->Buffer(id)->Read(dst, n); const auto local_comp_time = SCL_LOCAL_COMP_END; const auto exec_time = ctx->Checkpoint(id.local) - local_comp_time; @@ -124,49 +112,85 @@ EventPtr scl::sim::SimulateRecv(std::shared_ptr ctx, } std::pair scl::sim::SimulateHasData( - std::shared_ptr ctx, + std::shared_ptr ctx, ChannelId id) { // The other party hasn't had a chance to run yet, so it's not possible to // determine if there's data available for us. if (ctx->Trace(id.remote).empty()) { ctx->AddCandidateToRun(id.remote); - throw SimulationFailure(); + throw SimulationFailure("other party hasnt started yet"); } - const auto other_latest = ctx->LatestTimestamp(id.remote); + // We determine if there is data available by inspecting the list of WriteOps + // created by the remote party. Since each WriteOp has a timestamp, we can use + // that to determine if the data would have arrived at us yet. + // + // The rules for what to return, and when to fail the simulation goes as + // follows: + // + // - WriteOp op exists such that op.amount > 0. This op corresponds to the + // data that we would receive the next time we call Recv on this channel. + // + // If it is the case that + // + // op.time + time_to_send_1_byte <= our_current_time, + // + // then we can return has_data == true. Otherwise, we can return false. + // Note that, even if the remote party is behind is in time, we know that + // it is not possible for it to send data that we would receive earlier + // than the data connected to op. + // + // - No WriteOp exists. In this case, we either return has_data == false, or + // we fail the simulation. We can return has_data == false if + // + // remote_current_time - time_to_send_1_byte >= our_current_time + // + // as we know that no Send that the remote party makes, would have arrived + // to us before now. On the other hand, if the above does not hold, then we + // cannot say for sure that the remote party might not send data that we + // would be able to receive now, and so we have to fail the simulation. + + // Time it takes for 1 byte to go from the remote party to us. + const auto offset = ComputeRecvTime(ctx->ChannelConfiguration(id.Flip()), 1); + + // Go through each write op of the other party, and find the earliest one. const auto me_latest = ctx->Checkpoint(id.local); - - // The other party is still running, but is chronologically behind us, so it's - // not possible to determine if there's data available for us. - if (ctx->Trace(id.remote).back()->EventType() != Event::Type::STOP && - other_latest < me_latest) { - ctx->AddCandidateToRun(id.remote); - throw SimulationFailure(); - } - - // Check all handled writes of the other party. If there's one which took - // place before me_latest that we haven't read yet, then there's data - // available. Note that ordering of the writes do not matter here. We're just - // interested in some unhandled write. bool has_data = false; - for (const auto& wop : ctx->Writes(id.Flip())) { - if (wop.time <= me_latest && wop.amount > 0) { + bool has_result = false; + if (ctx->HasWrite(id.Flip())) { + if (ctx->NextWrite(id.Flip()).time + offset <= me_latest) { has_data = true; - break; + } else { + has_data = false; + has_result = true; + } + } + + // Handle the case where no WriteOp existed at all. Here we will fail the + // simulation if the remote party is too far behind us in time. + if (!has_data && !has_result) { + const auto other_latest = ctx->LatestTimestamp(id.remote) - offset; + if (!ctx->HasTerminated(id.remote) && other_latest <= me_latest) { + ctx->AddCandidateToRun(id.remote); + throw SimulationFailure("no data, and we're ahead"); } } const auto event = std::make_shared(me_latest, id, has_data); - ctx->AddEvent(id.local, event); return {has_data, event}; } -void scl::sim::SimulatedChannel::Send(const scl::net::Packet& packet) { +void scl::sim::Channel::Send(const scl::net::Packet& packet) { const auto packet_size = packet.Size(); const auto size_size = sizeof(net::Packet::SizeType); + // A packet is a size + content, which are sent separately. scl::net::Channel::Send(packet); + // Sending the size and conte each generate a "SEND" event. These are removed + // here, and replaced by a single "PACKET_SEND" event that is set to have + // happened at the same time as the first event, and with an amount equal to + // the sum of the two events. const auto data_event = m_ctx->PopLastEvent(m_id.local); const auto size_event = m_ctx->PopLastEvent(m_id.local); const auto event = @@ -185,10 +209,19 @@ std::size_t GetDataAmount(scl::sim::Event* event) { } // namespace -std::optional scl::sim::SimulatedChannel::Recv(bool block) { +std::optional scl::sim::Channel::Recv(bool block) { + // A packet is received a little differently, depending on whether it blocks + // or not. If the recv is blocking, then we receive a size + content. If the + // receive is non-blocking, then we first check if there's data before + // receiving the size + content. + auto p = net::Channel::Recv(block); if (block) { + // Receive was blocking, so we need to remove the two last events, + // corresponding to the receiving the size of the packet, and the packet's + // content. The information in these two events is then turned into a + // PACKET_RECV event. const auto data_event = m_ctx->PopLastEvent(m_id.local); const auto size_event = m_ctx->PopLastEvent(m_id.local); @@ -201,6 +234,12 @@ std::optional scl::sim::SimulatedChannel::Recv(bool block) { GetDataAmount(data_event.get()) + GetDataAmount(size_event.get()), true)); } else { + // If the receive was non-blocking, then we either have one event (in case + // there was no data to receive), or three (in case there was data to + // receive). + // + // The extra event here, compared to the blocking case, is an event arising + // from a call to HasData. if (p.has_value()) { const auto data_event = m_ctx->PopLastEvent(m_id.local); const auto size_event = m_ctx->PopLastEvent(m_id.local); diff --git a/src/scl/simulation/config.cc b/src/scl/simulation/config.cc index 1e32edb..a9e36b4 100644 --- a/src/scl/simulation/config.cc +++ b/src/scl/simulation/config.cc @@ -20,7 +20,9 @@ #include #include -void scl::sim::SimulatedNetworkConfig::Builder::Validate() const { +using namespace scl; + +void sim::ChannelConfig::Builder::Validate() const { if (m_bandwidth.has_value()) { if (m_bandwidth.value() == 0) { throw std::invalid_argument("bandwidth cannot be 0"); @@ -49,28 +51,26 @@ void scl::sim::SimulatedNetworkConfig::Builder::Validate() const { } } -std::ostream& scl::sim::operator<<(std::ostream& os, - const SimulatedNetworkConfig& config) { - os << "SimulationConfig{"; - os << "Bandwidth: " << config.Bandwidth() << " bits/s, "; - os << "RTT: " << config.RTT() << " ms, "; - os << "MSS: " << config.MSS() << " bytes, "; - os << "PackageLoss: " << 100 * config.PackageLoss() << "%, "; - os << "WindowSize: " << config.WindowSize() << " bytes}"; +std::ostream& sim::operator<<(std::ostream& os, const ChannelConfig& config) { + if (config.Type() == sim::ChannelConfig::NetworkType::TCP) { + os << "SimulationConfig{"; + os << "Type: TCP, "; + os << "Bandwidth: " << config.Bandwidth() << " bits/s, "; + os << "RTT: " << config.RTT() << " ms, "; + os << "MSS: " << config.MSS() << " bytes, "; + os << "PackageLoss: " << 100 * config.PackageLoss() << "%, "; + os << "WindowSize: " << config.WindowSize() << " bytes}"; + } else { + os << "SimulationConfig{INSTANT}"; + } return os; } -scl::sim::SimulatedNetworkConfig scl::sim::SimulatedNetworkConfig::Default() { - return SimulatedNetworkConfig::Builder{}.Build(); +sim::ChannelConfig sim::ChannelConfig::Default() { + return ChannelConfig::Builder{}.Build(); } -scl::sim::SimulatedNetworkConfig scl::sim::SimulatedNetworkConfig::Loopback() { - return SimulatedNetworkConfig::Builder{} - .Bandwidth(-1) - .MSS(-1) - .PackageLoss(0) - .RTT(0) - .WindowSize(-1) - .Build(); +sim::ChannelConfig sim::ChannelConfig::Loopback() { + return ChannelConfig::Builder{}.Type(NetworkType::INSTANT).Build(); } diff --git a/src/scl/simulation/context.cc b/src/scl/simulation/context.cc index a33dc0a..832f7a0 100644 --- a/src/scl/simulation/context.cc +++ b/src/scl/simulation/context.cc @@ -17,16 +17,19 @@ #include "scl/simulation/context.h" +#include "scl/simulation/config.h" +#include "scl/simulation/event.h" #include "scl/simulation/mem_channel_buffer.h" #include "scl/simulation/simulator.h" -using SimCtx = scl::sim::SimulationContext; +using namespace scl; template <> -std::shared_ptr SimCtx::Create( +std::shared_ptr +sim::Context::Create( std::size_t number_of_parties, - const SimulatedNetworkConfigCreator& config) { - auto ctx = std::make_shared(config); + std::shared_ptr config) { + auto ctx = std::make_shared(config); ctx->m_nparties = number_of_parties; ctx->m_traces.resize(number_of_parties); @@ -51,14 +54,9 @@ std::size_t Next(std::size_t id, std::size_t n) { return (id + 1) % n; } -bool HasTerminated(const scl::sim::SimulationTrace& trace) { - return !trace.empty() && - trace.back()->EventType() == scl::sim::Event::Type::STOP; -} - } // namespace -std::optional SimCtx::NextToRun( +std::optional sim::Context::NextToRun( std::optional current) { // party 0 is always the party to go first. if (!current.has_value()) { @@ -70,11 +68,11 @@ std::optional SimCtx::NextToRun( if (m_state == State::ROLLBACK) { // the last party in m_next_party_cand is assumed to be the party for // which current tried to Recv or HasData from. - auto next = m_next_party_cand.back(); + const auto next = m_next_party_cand.back(); // if this party has already finished, then current will never be able to // finish, so we crash the simulation here. - if (HasTerminated(m_traces[next])) { + if (HasTerminated(next)) { throw SimulationFailure( "party tried to receive data from terminated party"); } @@ -85,12 +83,14 @@ std::optional SimCtx::NextToRun( if (next == current) { throw SimulationFailure("infinite loop detected"); } + + return next; } std::size_t next = Next(current.value(), m_nparties); std::size_t terminated = 0; while (terminated < m_nparties) { - if (!HasTerminated(m_traces[next])) { + if (!HasTerminated(next)) { return next; } terminated++; @@ -100,14 +100,14 @@ std::optional SimCtx::NextToRun( return {}; } -scl::util::Time::Duration SimCtx::Checkpoint(std::size_t id) { +util::Time::Duration sim::Context::Checkpoint(std::size_t id) { const auto latest = LatestTimestamp(id); const auto last_checkpoint = m_checkpoint; UpdateCheckpoint(); return latest + (m_checkpoint - last_checkpoint); } -void SimCtx::Prepare(std::size_t id) { +void sim::Context::Prepare(std::size_t id) { if (m_state == State::COMMIT || m_state == State::ROLLBACK) { // Save the current head of m_traces so we can discard new events if this // party has to rollback. @@ -128,7 +128,7 @@ void SimCtx::Prepare(std::size_t id) { m_state = State::PREPARE; } -void SimCtx::Commit(std::size_t id) { +void sim::Context::Commit(std::size_t id) { if (m_state == State::PREPARE) { m_writes_backup.clear(); for (std::size_t i = 0; i < m_nparties; ++i) { @@ -142,7 +142,7 @@ void SimCtx::Commit(std::size_t id) { m_state = State::COMMIT; } -void SimCtx::Rollback(std::size_t id) { +void sim::Context::Rollback(std::size_t id) { if (m_state == State::PREPARE) { m_traces[id].resize(m_trace_index); m_writes = m_writes_backup; diff --git a/src/scl/simulation/event.cc b/src/scl/simulation/event.cc index 267f614..99385a8 100644 --- a/src/scl/simulation/event.cc +++ b/src/scl/simulation/event.cc @@ -20,87 +20,91 @@ #include #include -namespace { +using namespace scl; -using Evt = scl::sim::Event; +namespace { -auto EventTypeToString(Evt::Type type) { - if (type == Evt::Type::START) { +auto EventTypeToString(sim::Event::Type type) { + if (type == sim::Event::Type::START) { return "START"; } - if (type == Evt::Type::STOP) { + if (type == sim::Event::Type::STOP) { return "STOP"; } - if (type == Evt::Type::SEND) { + if (type == sim::Event::Type::SEND) { return "SEND"; } - if (type == Evt::Type::RECV) { + if (type == sim::Event::Type::RECV) { return "RECV"; } - if (type == Evt::Type::HAS_DATA) { + if (type == sim::Event::Type::HAS_DATA) { return "HAS_DATA"; } - if (type == Evt::Type::OUTPUT) { + if (type == sim::Event::Type::OUTPUT) { return "OUTPUT"; } - if (type == Evt::Type::SLEEP) { + if (type == sim::Event::Type::SLEEP) { return "SLEEP"; } - if (type == Evt::Type::SEGMENT_BEGIN) { + if (type == sim::Event::Type::SEGMENT_BEGIN) { return "SEGMENT_BEGIN"; } - if (type == Evt::Type::SEGMENT_END) { + if (type == sim::Event::Type::SEGMENT_END) { return "SEGMENT_END"; } - if (type == Evt::Type::CHECKPOINT) { + if (type == sim::Event::Type::CHECKPOINT) { return "CHECKPOINT"; } - if (type == Evt::Type::PACKET_SEND) { + if (type == sim::Event::Type::PACKET_SEND) { return "PACKET_SEND"; } - if (type == Evt::Type::PACKET_RECV) { + if (type == sim::Event::Type::PACKET_RECV) { return "PACKET_RECV"; } - // if (type == scl::Measurement::Type::CLOSE) + if (type == sim::Event::Type::KILLED) { + return "KILLED"; + } + + // if (type == Measurement::Type::CLOSE) return "CLOSE"; } -void WriteClose(std::ostream& os, const scl::sim::NetworkEvent* m) { +void WriteClose(std::ostream& os, const sim::NetworkEvent* m) { os << " [Local=" << m->LocalParty() << ", Remote=" << m->RemoteParty() << "]"; } -void WriteSend(std::ostream& os, const scl::sim::NetworkDataEvent* m) { +void WriteSend(std::ostream& os, const sim::NetworkDataEvent* m) { os << " [" << "Sender=" << m->LocalParty() << ", Receiver=" << m->RemoteParty() << ", Amount=" << m->DataAmount() << "]"; } -void WriteRecv(std::ostream& os, const scl::sim::NetworkDataEvent* m) { +void WriteRecv(std::ostream& os, const sim::NetworkDataEvent* m) { os << " [" << "Receiver=" << m->LocalParty() << ", Sender=" << m->RemoteParty() << ", Amount=" << m->DataAmount() << "]"; } -void WritePacketRecv(std::ostream& os, const scl::sim::PacketRecvEvent* m) { +void WritePacketRecv(std::ostream& os, const sim::PacketRecvEvent* m) { os << " [" << "Receiver=" << m->LocalParty() << ", Sender=" << m->RemoteParty() << ", Amount=" << m->DataAmount() << ", Blocking=" << std::boolalpha << m->Blocking() << "]"; } -void WriteSegment(std::ostream& os, const scl::sim::SegmentEvent* m) { +void WriteSegment(std::ostream& os, const sim::SegmentEvent* m) { const auto name = m->Name(); if (name.empty()) { os << " [Unnamed segment]"; @@ -109,21 +113,25 @@ void WriteSegment(std::ostream& os, const scl::sim::SegmentEvent* m) { } } -void WriteHasData(std::ostream& os, const scl::sim::HasDataEvent* m) { +void WriteHasData(std::ostream& os, const sim::HasDataEvent* m) { os << " [Local=" << m->LocalParty() << ", Remote=" << m->RemoteParty() << ", DataAvailable=" << std::boolalpha << m->HadData() << "]"; } -void WriteCheckpoint(std::ostream& os, const scl::sim::CheckpointEvent* m) { - os << " [" << m->Message() << "]"; +void WriteCheckpoint(std::ostream& os, const sim::CheckpointEvent* m) { + os << " [" << m->Id() << "]"; } } // namespace -std::ostream& scl::sim::operator<<(std::ostream& os, const Evt* m) { +std::ostream& sim::operator<<(std::ostream& os, Event::Type type) { + return os << EventTypeToString(type); +} + +std::ostream& sim::operator<<(std::ostream& os, const sim::Event* m) { using namespace std::chrono; const auto t = m->EventType(); - os << EventTypeToString(t) << " at "; + os << t << " at "; os << duration(m->Timestamp()).count(); os << " ms"; if (m->Offset() > util::Time::Duration::zero()) { @@ -132,31 +140,32 @@ std::ostream& scl::sim::operator<<(std::ostream& os, const Evt* m) { os << " ms]"; } - if (t == Evt::Type::SEGMENT_BEGIN || t == Evt::Type::SEGMENT_END) { + if (t == sim::Event::Type::SEGMENT_BEGIN || + t == sim::Event::Type::SEGMENT_END) { WriteSegment(os, dynamic_cast(m)); } - if (t == Evt::Type::CLOSE) { + if (t == sim::Event::Type::CLOSE) { WriteClose(os, dynamic_cast(m)); } - if (t == Evt::Type::SEND || t == Evt::Type::PACKET_SEND) { + if (t == sim::Event::Type::SEND || t == sim::Event::Type::PACKET_SEND) { WriteSend(os, dynamic_cast(m)); } - if (t == Evt::Type::RECV) { + if (t == sim::Event::Type::RECV) { WriteRecv(os, dynamic_cast(m)); } - if (t == Evt::Type::PACKET_RECV) { + if (t == sim::Event::Type::PACKET_RECV) { WritePacketRecv(os, dynamic_cast(m)); } - if (t == Evt::Type::HAS_DATA) { + if (t == sim::Event::Type::HAS_DATA) { WriteHasData(os, dynamic_cast(m)); } - if (t == Evt::Type::CHECKPOINT) { + if (t == sim::Event::Type::CHECKPOINT) { WriteCheckpoint(os, dynamic_cast(m)); } diff --git a/src/scl/simulation/measurement.cc b/src/scl/simulation/measurement.cc index 2ae3a42..4de04d7 100644 --- a/src/scl/simulation/measurement.cc +++ b/src/scl/simulation/measurement.cc @@ -17,10 +17,67 @@ #include "scl/simulation/measurement.h" -std::ostream& scl::sim::operator<<(std::ostream& os, - const scl::sim::TimeMeasurement& m) { - auto mean = std::chrono::duration(m.Mean()).count(); - auto std_dev = std::chrono::duration(m.StdDev()).count(); +#include + +using namespace scl; + +namespace { + +template +T Zero() { + return 0; +} + +template <> +util::Time::Duration Zero() { + return util::Time::Duration::zero(); +} + +template +T Mean(const sim::Measurement& m) { + T sum = Zero(); + for (const auto& v : m.Samples()) { + sum += v; + } + return sum / m.Size(); +} + +long double Sqrt(long double v) { + return std::sqrt(v); +} + +long double Sqr(long double v) { + return v * v; +} + +util::Time::Duration Sqrt(const util::Time::Duration& v) { + long double u = std::sqrt(v.count()); + std::chrono::duration w(u); + return std::chrono::duration_cast(w); +} + +util::Time::Duration Sqr(const util::Time::Duration& v) { + long double u = v.count(); + std::chrono::duration w(u * u); + return std::chrono::duration_cast(w); +} + +template +T StdDev(const sim::Measurement& m) { + const auto mu = Mean(m); + auto sum = Zero(); + for (const auto& v : m.Samples()) { + sum += Sqr(v - mu); + } + return Sqrt(sum / m.Size()); +} + +} // namespace + +std::ostream& sim::operator<<(std::ostream& os, const sim::TimeMeasurement& m) { + const auto mean = std::chrono::duration(Mean(m)).count(); + const auto std_dev = + std::chrono::duration(StdDev(m)).count(); os << "{" << "\"mean\": " << mean << ", " @@ -30,12 +87,14 @@ std::ostream& scl::sim::operator<<(std::ostream& os, return os; } -std::ostream& scl::sim::operator<<(std::ostream& os, - const scl::sim::DataMeasurement& m) { +std::ostream& sim::operator<<(std::ostream& os, const sim::DataMeasurement& m) { + const auto mean = Mean(m); + const auto std_dev = StdDev(m); + os << "{" - << "\"mean\": " << m.Mean() << ", " + << "\"mean\": " << mean << ", " << "\"unit\": \"B\", " - << "\"std_dev\": " << m.StdDev() << "}"; + << "\"std_dev\": " << std_dev << "}"; return os; } diff --git a/src/scl/simulation/result.cc b/src/scl/simulation/result.cc index 910698d..7eed414 100644 --- a/src/scl/simulation/result.cc +++ b/src/scl/simulation/result.cc @@ -32,6 +32,8 @@ #include "scl/simulation/measurement.h" #include "scl/util/time.h" +using namespace scl; + namespace { struct SentRecv { @@ -41,27 +43,31 @@ struct SentRecv { using SentRecvMap = std::unordered_map; +using CheckpointMap = std::unordered_map; + struct Segment { // sent/recv to other parties SentRecvMap sr; // execution time of the segment - scl::util::Time::Duration dur; + util::Time::Duration dur; + // checkpoints found in the segment + CheckpointMap checkpoints; }; -std::string GetNameFromSegmentEvent(std::shared_ptr event) { - return std::dynamic_pointer_cast(event)->Name(); +std::string GetNameFromSegmentEvent(std::shared_ptr event) { + return std::dynamic_pointer_cast(event)->Name(); } using NamedSegment = std::pair; -bool IsRecvEvent(std::shared_ptr ptr) { - return ptr->EventType() == scl::sim::Event::Type::RECV || - ptr->EventType() == scl::sim::Event::Type::PACKET_RECV; +bool IsRecvEvent(std::shared_ptr ptr) { + return ptr->EventType() == sim::Event::Type::RECV || + ptr->EventType() == sim::Event::Type::PACKET_RECV; } -bool IsSendEvent(std::shared_ptr ptr) { - return ptr->EventType() == scl::sim::Event::Type::SEND || - ptr->EventType() == scl::sim::Event::Type::PACKET_SEND; +bool IsSendEvent(std::shared_ptr ptr) { + return ptr->EventType() == sim::Event::Type::SEND || + ptr->EventType() == sim::Event::Type::PACKET_SEND; } /** @@ -73,7 +79,7 @@ template NamedSegment ParseSegment(It start, const It end) { Segment seg; - std::shared_ptr event = *start; + std::shared_ptr event = *start; const auto name = GetNameFromSegmentEvent(event); seg.dur = event->Timestamp(); @@ -82,7 +88,8 @@ NamedSegment ParseSegment(It start, const It end) { while (start < end) { event = *start; - auto ne = std::dynamic_pointer_cast(event); + + auto ne = std::dynamic_pointer_cast(event); if (ne != nullptr) { const auto id = ne->RemoteParty(); @@ -94,7 +101,12 @@ NamedSegment ParseSegment(It start, const It end) { } } - if (event->EventType() == scl::sim::Event::Type::SEGMENT_END) { + if (event->EventType() == sim::Event::Type::CHECKPOINT) { + const auto* ce = dynamic_cast(event.get()); + seg.checkpoints[ce->Id()] = ce->Timestamp(); + } + + if (event->EventType() == sim::Event::Type::SEGMENT_END) { seg.dur = event->Timestamp() - seg.dur; return {name, seg}; } @@ -130,7 +142,7 @@ void UpdateSentRecv(SentRecvMap& m0, const SentRecvMap& m1) { } } -using SegmentMap = std::unordered_map; +using SegmentMap = std::unordered_map; /** * @brief Merge segments by their name. @@ -141,7 +153,7 @@ using SegmentMap = std::unordered_map; SegmentMap MergeSegments(const std::vector& segments) { SegmentMap m; - m[{}].dur = scl::util::Time::Duration::zero(); + m[{}].dur = util::Time::Duration::zero(); for (const auto& named_seg : segments) { const auto name = named_seg.first; @@ -152,6 +164,8 @@ SegmentMap MergeSegments(const std::vector& segments) { } else { m[name].dur += segm.dur; UpdateSentRecv(m[name].sr, segm.sr); + m[name].checkpoints.insert(segm.checkpoints.begin(), + segm.checkpoints.end()); } m[{}].dur += segm.dur; @@ -163,10 +177,11 @@ SegmentMap MergeSegments(const std::vector& segments) { template void ValidateTraceHeadAndTail(It head, It tail) { - if ((*head)->EventType() != scl::sim::Event::Type::START) { + if ((*head)->EventType() != sim::Event::Type::START) { throw std::logic_error("incomplete trace"); } - if ((*tail)->EventType() != scl::sim::Event::Type::STOP) { + const auto last = (*tail)->EventType(); + if (last != sim::Event::Type::STOP && last != sim::Event::Type::KILLED) { throw std::logic_error("truncated trace"); } } @@ -183,9 +198,10 @@ void AppendIfMissing(std::vector& list, /** * @brief Create a result from a list of simulation traces. */ -scl::sim::Result scl::sim::Result::Create( - const std::vector& traces) { +sim::Result sim::Result::Create( + const std::vector& traces) { std::vector segments; + for (const auto& trace : traces) { auto b = trace.begin(); const auto e = trace.end(); @@ -196,9 +212,9 @@ scl::sim::Result scl::sim::Result::Create( // Extract each segment std::vector named_segments; while (b < e) { - std::shared_ptr event = *b; + std::shared_ptr event = *b; - if (event->EventType() == scl::sim::Event::Type::SEGMENT_BEGIN) { + if (event->EventType() == sim::Event::Type::SEGMENT_BEGIN) { named_segments.emplace_back(ParseSegment(b, e)); } @@ -211,11 +227,19 @@ scl::sim::Result scl::sim::Result::Create( std::vector segment_names; std::unordered_map segment_measurements; + std::unordered_map checkpoints; for (const auto& seg_map : segments) { for (const auto& [seg_name, seg] : seg_map) { if (seg_name.has_value()) { - AppendIfMissing(segment_names, seg_name.value()); + // clang-tidy cannot see that we check if seg_name has a value above, so + // disable the linter here to avoid false negatives. + const auto v = seg_name.value(); // NOLINT + AppendIfMissing(segment_names, v); + } + + for (const auto& [s, c] : seg.checkpoints) { + checkpoints[s].AddSample(c); } segment_measurements[seg_name].duration_m.AddSample(seg.dur); @@ -233,21 +257,21 @@ scl::sim::Result scl::sim::Result::Create( } } - return Result(traces, segment_measurements, segment_names); + return Result(traces, segment_measurements, checkpoints, segment_names); } -std::vector scl::sim::Result::Create( - const std::vector>& traces) { +std::vector sim::Result::Create( + const std::vector>& traces) { const auto num_parties = traces[0].size(); - const auto num_iterations = traces.size(); + const auto num_replications = traces.size(); std::vector results; results.reserve(num_parties); for (std::size_t i = 0; i < num_parties; ++i) { std::vector traces_for_party; - traces_for_party.reserve(num_iterations); - for (std::size_t j = 0; j < num_iterations; ++j) { + traces_for_party.reserve(num_replications); + for (std::size_t j = 0; j < num_replications; ++j) { traces_for_party.emplace_back(traces[j][i]); } @@ -271,7 +295,7 @@ std::vector KeySet(const std::unordered_map& map) { } // namespace -std::vector scl::sim::Result::Interactions( +std::vector sim::Result::Interactions( const SegmentName& name) const { return KeySet(m_measurements.at(name).channels_m); } @@ -287,46 +311,41 @@ void WriteSegmentTrace(std::ostream& stream, It start, It end) { } // namespace -void scl::sim::Result::WriteTrace( - std::ostream& stream, - std::size_t iteration, - const scl::sim::Result::SegmentName& name) const { - if (iteration >= m_traces.size()) { - throw std::invalid_argument("invalid iteration"); +void sim::Result::WriteTrace(std::ostream& stream, + std::size_t replication, + const sim::Result::SegmentName& name) const { + if (replication >= m_traces.size()) { + throw std::invalid_argument("invalid replication"); } if (!name.has_value()) { WriteSegmentTrace(stream, - m_traces[iteration].begin(), - m_traces[iteration].end()); + m_traces[replication].begin(), + m_traces[replication].end()); } else { - auto start = m_traces[iteration].begin(); - auto end = m_traces[iteration].end(); - const auto& segment_name = name.value(); - while (start != end) { - const auto seg_ev = std::dynamic_pointer_cast(*start); + bool in_relevant_segment = false; - if (seg_ev != nullptr && - seg_ev->EventType() == Event::Type::SEGMENT_BEGIN && - seg_ev->Name() == segment_name) { - break; + for (const auto& e : m_traces[replication]) { + if (in_relevant_segment) { + stream << e << std::endl; } - start++; - } - auto offset = start + 1; - while (offset != end) { - const auto seg_ev = std::dynamic_pointer_cast(*offset); - if (seg_ev != nullptr && - seg_ev->EventType() == Event::Type::SEGMENT_END && - seg_ev->Name() == segment_name) { - break; + const auto s = std::dynamic_pointer_cast(e); + + if (s != nullptr) { + if (!in_relevant_segment && + s->EventType() == Event::Type::SEGMENT_BEGIN && + s->Name() == segment_name) { + stream << e << std::endl; + in_relevant_segment = true; + } + + if (in_relevant_segment && s->EventType() == Event::Type::SEGMENT_END) { + in_relevant_segment = false; + } } - offset++; } - - WriteSegmentTrace(stream, start, offset + 1); } } @@ -352,7 +371,7 @@ void WriteObj(std::ostream& stream, const long double& val) { stream << val; } -void WriteObj(std::ostream& stream, const scl::util::Time::Duration& d) { +void WriteObj(std::ostream& stream, const util::Time::Duration& d) { auto t = std::chrono::duration(d).count(); WriteObj(stream, t); } @@ -374,37 +393,25 @@ void WriteUnit(std::ostream& stream) { } template <> -void WriteUnit(std::ostream& stream) { +void WriteUnit(std::ostream& stream) { WriteObj(stream, std::string{"milliseconds"}); } template -void WriteObj(std::ostream& stream, const scl::sim::Measurement& m) { +void WriteList(std::ostream& stream, const std::vector& items); + +template +void WriteObj(std::ostream& stream, const sim::Measurement& m) { stream << "{"; - WriteKey(stream, "samples"); - WriteObj(stream, m.Size()); - stream << ","; WriteKey(stream, "unit"); WriteUnit(stream); stream << ","; - WriteKey(stream, "mean"); - WriteObj(stream, m.Mean()); - stream << ","; - WriteKey(stream, "median"); - WriteObj(stream, m.Median()); - stream << ","; - WriteKey(stream, "min"); - WriteObj(stream, m.Min()); - stream << ","; - WriteKey(stream, "max"); - WriteObj(stream, m.Max()); - stream << ","; - WriteKey(stream, "std_dev"); - WriteObj(stream, m.StdDev()); + WriteKey(stream, "samples"); + WriteList(stream, m.Samples()); stream << "}"; } -void WriteObj(std::ostream& stream, const scl::sim::SendRecvMeasurement& srm) { +void WriteObj(std::ostream& stream, const sim::SendRecvMeasurement& srm) { stream << "{"; WriteKey(stream, "sent"); WriteObj(stream, srm.sent); @@ -414,13 +421,12 @@ void WriteObj(std::ostream& stream, const scl::sim::SendRecvMeasurement& srm) { stream << "}"; } -void WriteObj(std::ostream& stream, - const scl::sim::Result::SegmentMeasurement& m) { +void WriteObj(std::ostream& stream, const sim::Result::SegmentMeasurement& m) { stream << "{"; - WriteKey(stream, "execution_time"); + WriteKey(stream, "time"); WriteObj(stream, m.duration_m); stream << ","; - WriteKey(stream, "send_recv"); + WriteKey(stream, "data"); WriteObj(stream, m.send_recv_m); stream << ","; WriteKey(stream, "channels"); @@ -459,7 +465,7 @@ void WriteMap(std::ostream& stream, const std::unordered_map& map) { } // namespace -void scl::sim::Result::Write(std::ostream& stream) const { +void sim::Result::Write(std::ostream& stream) const { stream << "{"; WriteKey(stream, "names"); @@ -468,6 +474,10 @@ void scl::sim::Result::Write(std::ostream& stream) const { WriteKey(stream, "measurements"); WriteMap(stream, m_measurements); + stream << ","; + + WriteKey(stream, "checkpoints"); + WriteMap(stream, m_checkpoints); stream << "}" << std::endl; } diff --git a/src/scl/simulation/simulate_recv_time.cc b/src/scl/simulation/simulate_recv_time.cc index 2eecab3..a7c8d40 100644 --- a/src/scl/simulation/simulate_recv_time.cc +++ b/src/scl/simulation/simulate_recv_time.cc @@ -22,6 +22,8 @@ #include "scl/simulation/config.h" #include "scl/simulation/simulator.h" +using namespace scl; + namespace { /** @@ -43,8 +45,7 @@ long double TransferSizeWithHeadersBits(std::size_t nbytes, /** * @brief Get the RTT from a config in seconds. */ -long double RoundTripTimeSeconds( - const scl::sim::SimulatedNetworkConfig& config) noexcept { +long double RoundTripTimeSeconds(const sim::ChannelConfig& config) noexcept { using namespace std::chrono_literals; const auto d = std::chrono::milliseconds(config.RTT()); return d / 1.0s; @@ -54,7 +55,7 @@ long double RoundTripTimeSeconds( * @brief Compute the maximum TCP throughput assuming package loss of 0% */ long double ThroughputZeroPackageLoss( - const scl::sim::SimulatedNetworkConfig& config) noexcept { + const sim::ChannelConfig& config) noexcept { // Simple throughput formula: // https://tetcos.com/pdf/v13/Experiments/Mathematical-Modelling-of-TCP-Throughput-Performance.pdf const auto rtt = RoundTripTimeSeconds(config); @@ -72,7 +73,7 @@ long double ThroughputZeroPackageLoss( * @brief Compute TCP throughput assuming package loss using Mathis et. al. */ long double ThroughputNonZeroPackageLoss( - const scl::sim::SimulatedNetworkConfig& config) noexcept { + const sim::ChannelConfig& config) noexcept { const auto mss = (long double)config.MSS(); const auto loss_term = std::sqrt(3.0 / (2.0 * config.PackageLoss())); const auto rtt = RoundTripTimeSeconds(config); @@ -80,11 +81,8 @@ long double ThroughputNonZeroPackageLoss( return loss_term * (8 * mss / rtt); } -} // namespace - -scl::util::Time::Duration scl::sim::ComputeRecvTime( - const SimulatedNetworkConfig& config, - std::size_t n) { +util::Time::Duration ComputeRecvTimeTcp(const sim::ChannelConfig& config, + std::size_t n) { const auto total_size_bits = TransferSizeWithHeadersBits(n, config.MSS()); auto actual_tp = ThroughputZeroPackageLoss(config); @@ -97,3 +95,14 @@ scl::util::Time::Duration scl::sim::ComputeRecvTime( const auto t_sec = std::chrono::duration(t); return std::chrono::duration_cast(t_sec); } + +} // namespace + +util::Time::Duration sim::ComputeRecvTime(const ChannelConfig& config, + std::size_t n) { + if (config.Type() == sim::ChannelConfig::NetworkType::TCP) { + return ComputeRecvTimeTcp(config, n); + } + // sim::ChannelConfig::NetworkType::INSTANT + return util::Time::Duration::zero(); +} diff --git a/src/scl/simulation/simulator.cc b/src/scl/simulation/simulator.cc index 28cb9fd..f043e2e 100644 --- a/src/scl/simulation/simulator.cc +++ b/src/scl/simulation/simulator.cc @@ -23,6 +23,7 @@ #include #include #include +#include #include #include #include @@ -40,29 +41,39 @@ #include "scl/simulation/result.h" #include "scl/util/time.h" +using namespace scl; + namespace { -/** - * @brief Create an Event of some type and duration. - */ -auto CreateEvent(scl::sim::Event::Type t, scl::util::Time::Duration d) { - return std::make_shared(t, d); +auto CreateEvent(sim::Event::Type t, util::Time::Duration d) { + return std::make_shared(t, d); +} + +auto CreateSegmentEvent(util::Time::Duration t, + const std::string& n, + bool is_end) { + if (is_end) { + return std::make_shared(sim::Event::Type::SEGMENT_END, + t, + n); + } + return std::make_shared(sim::Event::Type::SEGMENT_BEGIN, + t, + n); } -std::vector CreateNetworks( - std::shared_ptr ctx) { - std::vector networks; +auto CreateNetworks(std::shared_ptr ctx) { + std::vector networks; const auto n = ctx->NumberOfParties(); networks.reserve(n); for (std::size_t i = 0; i < n; ++i) { - std::vector> channels; + std::vector> channels; channels.reserve(n); for (std::size_t j = 0; j < n; ++j) { - scl::sim::ChannelId cid(i, j); - channels.emplace_back( - std::make_shared(cid, ctx)); + sim::ChannelId cid(i, j); + channels.emplace_back(std::make_shared(cid, ctx)); } networks.emplace_back(channels, i); @@ -71,98 +82,101 @@ std::vector CreateNetworks( return networks; } // LCOV_EXCL_LINE -/** - * @brief Create a SEGMENT_END or SEGMENT_BEGIN event. - */ -std::shared_ptr CreateSegmentEvent(scl::util::Time::Duration t, - const std::string& n, - bool is_end) { - if (is_end) { - return std::make_shared( - scl::sim::Event::Type::SEGMENT_END, - t, - n); +struct RunResult { + std::unique_ptr next; + std::any output; +}; + +RunResult Run(std::shared_ptr ctx, + std::size_t id, + proto::Protocol* protocol, + proto::Env& env) { + RunResult result; + + if (ctx->Trace(id).empty()) { + ctx->AddEvent( + id, + CreateEvent(sim::Event::Type::START, util::Time::Duration::zero())); } - return std::make_shared( - scl::sim::Event::Type::SEGMENT_BEGIN, - t, - n); -} -/** - * @brief Run a protocol step for a party. - */ -std::unique_ptr Run( - std::shared_ptr ctx, - std::size_t party_id, - scl::proto::Protocol* party, - scl::proto::Env& env, - const scl::sim::OutputCallback& output_callback) { - if (ctx->Trace(party_id).empty()) { - ctx->AddEvent(party_id, - CreateEvent(scl::sim::Event::Type::START, - scl::util::Time::Duration::zero())); + if (protocol == nullptr) { + // handling of entries which are null. + ctx->AddEvent( + id, + CreateEvent(sim::Event::Type::STOP, util::Time::Duration::zero())); + return result; } ctx->AddEvent( - party_id, - CreateSegmentEvent(ctx->LatestTimestamp(party_id), party->Name(), false)); + id, + CreateSegmentEvent(ctx->LatestTimestamp(id), protocol->Name(), false)); ctx->UpdateCheckpoint(); - auto next = party->Run(env); - const auto exec_time = ctx->Checkpoint(party_id); - - const auto output = party->Output(); - if (output.has_value()) { - ctx->AddEvent(party_id, - CreateEvent(scl::sim::Event::Type::OUTPUT, exec_time)); - output_callback(party_id, output); + result.next = protocol->Run(env); + const auto exec_time = ctx->Checkpoint(id); + + result.output = protocol->Output(); + + if (result.output.has_value()) { + ctx->AddEvent(id, CreateEvent(sim::Event::Type::OUTPUT, exec_time)); } - ctx->AddEvent(party_id, CreateSegmentEvent(exec_time, party->Name(), true)); + ctx->AddEvent(id, CreateSegmentEvent(exec_time, protocol->Name(), true)); - if (next == nullptr) { - ctx->AddEvent(party_id, - CreateEvent(scl::sim::Event::Type::STOP, exec_time)); + if (result.next == nullptr) { + ctx->AddEvent(id, CreateEvent(sim::Event::Type::STOP, exec_time)); } - return next; + return result; } -/** - * @brief Run a simulation. - */ -std::vector RunSimulation( - std::vector> protocols, - const scl::sim::SimulatedNetworkConfigCreator& config_creator, - const scl::sim::OutputCallback& output_callback) { - const auto n = protocols.size(); - auto ps = std::move(protocols); - - auto ctx = - scl::sim::SimulationContext::Create( - n, - config_creator); +std::vector CreateEnvs(const std::vector& networks, + std::shared_ptr ctx) { + std::vector envs; + envs.reserve(ctx->NumberOfParties()); + for (std::size_t i = 0; i < ctx->NumberOfParties(); ++i) { + envs.emplace_back(proto::Env{networks[i], + std::make_unique(ctx, i), + std::make_unique(ctx, i)}); + } + return envs; +} + +auto RunSimulation(std::size_t replication, sim::Manager* manager) { + auto ps = manager->Protocol(); + auto ctx = sim::Context::Create( + ps.size(), + manager->NetworkConfiguration()); auto networks = CreateNetworks(ctx); + auto envs = CreateEnvs(networks, ctx); auto next_id = ctx->NextToRun(); + while (next_id.has_value()) { auto id = next_id.value(); try { ctx->Prepare(id); - scl::proto::Env env{ - networks[id], - std::make_unique(ctx, id), - std::make_unique(ctx, id)}; + auto result = Run(ctx, id, ps[id].get(), envs[id]); + + ps[id] = std::move(result.next); - ps[id] = Run(ctx, id, ps[id].get(), env, output_callback); + if (result.output.has_value()) { + manager->HandleOutput(replication, id, result.output); + } + + if (ps[id] != nullptr && manager->Terminate(id, ctx->GetView())) { + ps[id] = nullptr; + ctx->AddEvent( + id, + CreateEvent(sim::Event::Type::KILLED, ctx->LatestTimestamp(id))); + } ctx->Commit(id); - } catch (scl::sim::SimulationFailure& e) { + } catch (sim::SimulationFailure& e) { ctx->Rollback(id); } @@ -174,26 +188,12 @@ std::vector RunSimulation( } // namespace -std::vector scl::sim::Simulate( - const ProtocolCreator& protocol_creator, - const SimulatedNetworkConfigCreator& config_creator, - std::size_t iterations, - const OutputCallback& output_cb) { +std::vector sim::Simulate(std::unique_ptr manager) { std::vector> traces; - for (std::size_t i = 0; i < iterations; ++i) { - traces.emplace_back( - RunSimulation(protocol_creator(), config_creator, output_cb)); + auto network_conf = manager->NetworkConfiguration(); + for (std::size_t i = 0; i < manager->Replications(); ++i) { + traces.emplace_back(RunSimulation(i, manager.get())); } return Result::Create(traces); } - -std::vector scl::sim::Simulate( - std::vector> parties, - const SimulatedNetworkConfigCreator& config_creator, - const OutputCallback& output_cb) { - std::vector> traces; - traces.emplace_back( - RunSimulation(std::move(parties), config_creator, output_cb)); - return Result::Create(traces); -} diff --git a/test/scl/math/fields.h b/test/scl/math/fields.h index 5b3da54..3d4f9ac 100644 --- a/test/scl/math/fields.h +++ b/test/scl/math/fields.h @@ -27,7 +27,7 @@ using GF7 = math::FF; #ifdef SCL_ENABLE_EC_TESTS using Secp256k1_Field = math::FF; -using Secp256k1_Order = math::FF; +using Secp256k1_Order = math::FF; #endif } // namespace scl::test diff --git a/test/scl/math/test_ff.cc b/test/scl/math/test_ff.cc index 0d15fa5..2c9e24a 100644 --- a/test/scl/math/test_ff.cc +++ b/test/scl/math/test_ff.cc @@ -131,6 +131,7 @@ TEMPLATE_TEST_CASE("FF Multiplication", "[math][ff]", FIELD_DEFS) { REPEAT { auto a = RandomNonZero(prg); auto b = RandomNonZero(prg); + REQUIRE(a * b != zero); REQUIRE(a * b == b * a); auto c = RandomNonZero(prg); REQUIRE(c * (a + b) == c * a + c * b); @@ -191,3 +192,18 @@ TEMPLATE_TEST_CASE("FF serialization", "[math][ff]", FIELD_DEFS) { REQUIRE(a == b); } } + +TEMPLATE_TEST_CASE("FF Exp", "[math][ff]", FIELD_DEFS) { + using FF = TestType; + + auto prg = util::PRG::Create("FF exp"); + + auto a = RandomNonZero(prg); + + REQUIRE(a == Exp(a, 1)); + REQUIRE(a * a == Exp(a, 2)); + + REQUIRE(a * a * a * a * a * a == Exp(a, 6)); + + REQUIRE(FF::One() == Exp(a, 0)); +} diff --git a/test/scl/math/test_poly.cc b/test/scl/math/test_poly.cc index 2eaed81..e038965 100644 --- a/test/scl/math/test_poly.cc +++ b/test/scl/math/test_poly.cc @@ -44,6 +44,7 @@ TEMPLATE_TEST_CASE("Polynomial construct", "[ss][math]", FIELD_DEFS) { REQUIRE(x[0] == FF(1)); REQUIRE(x[1] == FF(2)); REQUIRE(x[2] == FF(6)); + REQUIRE(x.Coefficients() == coeff); math::Vec with_zeros = {FF(1), FF(0), FF(3), FF(0)}; auto y = math::Polynomial::Create(with_zeros); diff --git a/test/scl/math/test_secp256k1.cc b/test/scl/math/test_secp256k1.cc index caa5907..6e2f2c1 100644 --- a/test/scl/math/test_secp256k1.cc +++ b/test/scl/math/test_secp256k1.cc @@ -21,6 +21,7 @@ #include "scl/math/curves/secp256k1.h" #include "scl/math/ec_ops.h" +#include "scl/math/ff.h" #include "scl/math/fp.h" #include "scl/math/number.h" #include "scl/util/prg.h" @@ -28,7 +29,7 @@ using namespace scl; using Curve = math::EC; -using Scalar = Curve::Order; +using Scalar = Curve::ScalarField; using Field = Curve::Field; namespace { @@ -126,8 +127,7 @@ TEST_CASE("Secp256k1 generator", "[math][ec]") { ss << g; REQUIRE(ss.str() == g.ToString()); - auto ord = math::Number::FromString( - "fffffffffffffffffffffffffffffffebaaedce6af48a03bbfd25e8cd0364141"); + auto ord = math::Order(); REQUIRE(!g.PointAtInfinity()); auto poi = g * ord; @@ -202,13 +202,6 @@ TEST_CASE("Secp256k1 negation special case", "[math][ec]") { REQUIRE(math::CurveIsPointAtInfinity(point)); } -TEST_CASE("Secp256k1 double point special case", "[math][ec]") { - using CurveT = math::Secp256k1; - CurveT::ValueType point = {Field(1), Field(0), Field(1)}; - math::CurveDouble(point); - REQUIRE(math::CurveIsPointAtInfinity(point)); -} - TEST_CASE("Secp256k1 serialization", "[math][ec]") { auto prg = util::PRG::Create(); @@ -259,3 +252,11 @@ TEST_CASE("Secp256k1 serialization", "[math][ec]") { auto j = Curve::Read(buffer.get()); REQUIRE(i == j); } + +TEST_CASE("Secp256k1 order", "[math]") { + auto ord = math::Order(); + REQUIRE( + ord == + math::Number::FromString( + "FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEFFFFFC2F")); +} diff --git a/test/scl/math/test_vec.cc b/test/scl/math/test_vec.cc index f669b1e..b515df2 100644 --- a/test/scl/math/test_vec.cc +++ b/test/scl/math/test_vec.cc @@ -21,9 +21,12 @@ #include #include +#include "scl/math/curves/secp256k1.h" +#include "scl/math/ec.h" #include "scl/math/fp.h" #include "scl/math/mat.h" #include "scl/math/vec.h" +#include "scl/util/traits.h" using namespace scl; @@ -170,3 +173,25 @@ TEST_CASE("Vector sub vector", "[math][la]") { std::logic_error, Catch::Matchers::Message("invalid range")); } + +TEST_CASE("Vector scalar EC", "[math]") { + using Curve = math::EC; + + auto v = math::Vec{Curve::Generator(), + Curve::Generator(), + Curve::Generator()}; + + const auto s = Curve::ScalarField(123); + auto w = v.ScalarMultiply(s); + + REQUIRE(w[0] == Curve::Generator() * s); + REQUIRE(w[1] == Curve::Generator() * s); + REQUIRE(w[2] == Curve::Generator() * s); + + const auto z = math::Number(123); + auto u = w.ScalarMultiply(math::Number(123)); + + REQUIRE(u[0] == w[0] * z); + REQUIRE(u[1] == w[1] * z); + REQUIRE(u[2] == w[2] * z); +} diff --git a/test/scl/simulation/test_channel.cc b/test/scl/simulation/test_channel.cc new file mode 100644 index 0000000..4079323 --- /dev/null +++ b/test/scl/simulation/test_channel.cc @@ -0,0 +1,200 @@ +/* SCL --- Secure Computation Library + * Copyright (C) 2023 Anders Dalskov + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +#include +#include + +#include "scl/simulation/channel.h" +#include "scl/simulation/config.h" +#include "scl/simulation/context.h" +#include "scl/simulation/event.h" +#include "scl/simulation/mem_channel_buffer.h" +#include "scl/simulation/simulator.h" + +using namespace scl; + +namespace { + +struct InstantNetworkConfig final : sim::NetworkConfig { + sim::ChannelConfig Get(sim::ChannelId channel_id) override { + (void)channel_id; + return sim::ChannelConfig::Loopback(); + } +}; + +auto StartEvent(util::Time::Duration ts) { + return std::make_shared(sim::Event::Type::START, ts); +} + +auto StopEvent(util::Time::Duration ts) { + return std::make_shared(sim::Event::Type::STOP, ts); +} + +} // namespace + +TEST_CASE("Channel recv packet blocking", "[sim]") { + auto cfg = std::make_shared(); + auto ctx = sim::Context::Create(2, cfg); + auto chl0 = sim::Channel({0, 1}, ctx); + auto chl1 = sim::Channel({1, 0}, ctx); + + net::Packet p; + p << 123; + ctx->AddEvent(0, StartEvent(util::Time::Duration::zero())); + chl0.Send(p); + const auto t0 = ctx->Trace(0); + REQUIRE(t0.size() == 2); + REQUIRE(t0[0]->EventType() == sim::Event::Type::START); + REQUIRE(t0[1]->EventType() == sim::Event::Type::PACKET_SEND); + + ctx->AddEvent(1, StartEvent(util::Time::Duration::zero())); + chl1.Recv(); + + const auto t1 = ctx->Trace(1); + REQUIRE(t1.size() == 2); + REQUIRE(t1[0]->EventType() == sim::Event::Type::START); + REQUIRE(t1[1]->EventType() == sim::Event::Type::PACKET_RECV); +} + +TEST_CASE("Channel recv packet non-blocking", "[sim]") { + auto cfg = std::make_shared(); + auto ctx = sim::Context::Create(2, cfg); + auto chl0 = sim::Channel({0, 1}, ctx); + auto chl1 = sim::Channel({1, 0}, ctx); + + net::Packet p; + p << 123; + ctx->AddEvent(0, StartEvent(util::Time::Duration(1000))); + chl0.Send(p); + + ctx->AddEvent(1, StartEvent(util::Time::Duration::zero())); + auto pkt = chl1.Recv(false); + + REQUIRE_FALSE(pkt.has_value()); + auto t0 = ctx->Trace(1); + REQUIRE(t0.size() == 2); + REQUIRE(t0[0]->EventType() == sim::Event::Type::START); + REQUIRE(t0[1]->EventType() == sim::Event::Type::PACKET_RECV); + + ctx->AddEvent(1, StartEvent(ctx->LatestTimestamp(0))); + auto pkt0 = chl1.Recv(false); + + REQUIRE(pkt0.has_value()); + t0 = ctx->Trace(1); + REQUIRE(t0.size() == 4); + REQUIRE(t0[2]->EventType() == sim::Event::Type::START); + REQUIRE(t0[3]->EventType() == sim::Event::Type::PACKET_RECV); +} + +TEST_CASE("Channel recv chunked", "[sim]") { + auto cfg = std::make_shared(); + auto ctx = sim::Context::Create(2, cfg); + auto chl0 = sim::Channel({0, 1}, ctx); + auto chl1 = sim::Channel({1, 0}, ctx); + + unsigned char data[] = {1, 2, 3, 4}; + ctx->AddEvent(0, StartEvent(util::Time::Duration::zero())); + chl0.Send(data, 4); + + ctx->AddEvent(1, StartEvent(util::Time::Duration::zero())); + unsigned char recv[4] = {0}; + + REQUIRE(ctx->HasWrite({0, 1})); + REQUIRE(ctx->NextWrite({0, 1}).amount == 4); + chl1.Recv(recv, 2); + + REQUIRE(ctx->NextWrite({0, 1}).amount == 2); + chl1.Recv(recv + 2, 2); + REQUIRE_FALSE(ctx->HasWrite({0, 1})); + + REQUIRE(data[0] == recv[0]); + REQUIRE(data[1] == recv[1]); + REQUIRE(data[2] == recv[2]); + REQUIRE(data[3] == recv[3]); +} + +TEST_CASE("Channel HasData no data, but not far ahead", "[sim]") { + auto cfg = std::make_shared(); + auto ctx = sim::Context::Create(2, cfg); + + sim::Channel p0({0, 1}, ctx); + sim::Channel p1({1, 0}, ctx); + + // P1 at time 100000, P0 at time 0. So we can say for sure that P1 does not + // have data for P0. + + ctx->AddEvent(1, StartEvent(util::Time::Duration(100000))); + ctx->AddEvent(0, StartEvent(util::Time::Duration::zero())); + + ctx->UpdateCheckpoint(); + auto hd = p0.HasData(); + REQUIRE_FALSE(hd); +} + +TEST_CASE("Channel HasData no data, other party terminated", "[sim]") { + auto cfg = std::make_shared(); + auto ctx = sim::Context::Create(2, cfg); + + sim::Channel p0({0, 1}, ctx); + sim::Channel p1({1, 0}, ctx); + + ctx->AddEvent(1, StopEvent(util::Time::Duration::zero())); + ctx->AddEvent(0, StartEvent(util::Time::Duration::zero())); + + ctx->UpdateCheckpoint(); + auto hd = p0.HasData(); + REQUIRE_FALSE(hd); +} + +TEST_CASE("Channel HasData no data, fails", "[sim]") { + auto cfg = std::make_shared(); + auto ctx = sim::Context::Create(3, cfg); + + sim::Channel p0({0, 1}, ctx); + sim::Channel p1({1, 0}, ctx); + + // P1 at time 100000, P0 at time 0. So we can say for sure that P1 does not + // have data for P0. + + ctx->AddEvent(1, StartEvent(util::Time::Duration::zero())); + ctx->AddEvent(0, StartEvent(util::Time::Duration::zero())); + + ctx->UpdateCheckpoint(); + REQUIRE_THROWS_MATCHES(p0.HasData(), + sim::SimulationFailure, + Catch::Matchers::Message("no data, and we're ahead")); + auto next = ctx->NextToRun(0); + REQUIRE(next.value_or(-1) == 1); +} + +TEST_CASE("Channel HasData other party not started", "[sim]") { + auto cfg = std::make_shared(); + auto ctx = sim::Context::Create(3, cfg); + + sim::Channel p0({0, 1}, ctx); + sim::Channel p1({1, 0}, ctx); + + ctx->AddEvent(0, StartEvent(util::Time::Duration::zero())); + + ctx->UpdateCheckpoint(); + REQUIRE_THROWS_MATCHES( + p0.HasData(), + sim::SimulationFailure, + Catch::Matchers::Message("other party hasnt started yet")); + auto next = ctx->NextToRun(0); + REQUIRE(next.value_or(-1) == 1); +} diff --git a/test/scl/simulation/test_config.cc b/test/scl/simulation/test_config.cc index 30140ac..daac6ce 100644 --- a/test/scl/simulation/test_config.cc +++ b/test/scl/simulation/test_config.cc @@ -20,57 +20,110 @@ #include #include "scl/simulation/config.h" +#include "scl/simulation/simulator.h" using namespace scl; +using namespace std::chrono_literals; + +namespace { + +template +void ApproxDuration(util::Time::Duration d, T v, T b) { + if (v > d) { + REQUIRE(v - d <= b); + } else { + REQUIRE(d - v <= b); + } +} + +std::size_t KB(std::size_t bytes) { + return 1000 * bytes; +} + +std::size_t MB(std::size_t bytes) { + return 1000 * KB(bytes); +} + +} // namespace + +TEST_CASE("ComputeRecvTime default config", "[sim]") { + // https://wintelguy.com/wanperf.pl + // parameters: + // Link bandwidth (Mbit/s): 1 + // RTT (millisecond): 100 + // Packet loss (%): 0 + // MTU (Byte): 1500 + // L1/L2 frame overhead (Byte): 0 <-- not accounted for in scl + // TCP/IP (v4) header overhead (Byte): 40 + // TCP window (RWND) size (Byte): 65536 + // File size (MByte): 1 + + const auto cfg = sim::ChannelConfig::Default(); + const auto tenMB = MB(10); + const auto t = sim::ComputeRecvTime(cfg, tenMB); + ApproxDuration(t, 82s, 1s); +} + +TEST_CASE("ComputeRecvTime lossy", "[sim]") { + const auto cfg = sim::ChannelConfig::Builder().PackageLoss(0.001).Build(); + const auto tenMB = MB(10); + const auto t = sim::ComputeRecvTime(cfg, tenMB); + ApproxDuration(t, 82s, 1s); +} + +TEST_CASE("ComputeRecvTime lo", "[sim]") { + const auto cfg = sim::ChannelConfig::Loopback(); + const auto amount = MB(10000); + const auto t = sim::ComputeRecvTime(cfg, amount); + REQUIRE(t.count() == 0); +} TEST_CASE("SimulationConfig default", "[sim]") { - auto cfg = sim::SimulatedNetworkConfig::Default(); - - REQUIRE(cfg.Bandwidth() == sim::SimulatedNetworkConfig::kDefaultBandwidth); - REQUIRE(cfg.RTT() == sim::SimulatedNetworkConfig::kDefaultRTT); - REQUIRE(cfg.MSS() == sim::SimulatedNetworkConfig::kDefaultMSS); - REQUIRE(cfg.PackageLoss() == - sim::SimulatedNetworkConfig::kDefaultPackageLoss); - REQUIRE(cfg.WindowSize() == sim::SimulatedNetworkConfig::kDefaultWindowSize); + auto cfg = sim::ChannelConfig::Default(); + + REQUIRE(cfg.Bandwidth() == sim::ChannelConfig::DEFAULT_BANDWIDTH); + REQUIRE(cfg.RTT() == sim::ChannelConfig::DEFAULT_RTT); + REQUIRE(cfg.MSS() == sim::ChannelConfig::DEFAULT_MSS); + REQUIRE(cfg.PackageLoss() == sim::ChannelConfig::DEFAULT_PACKAGE_LOSS); + REQUIRE(cfg.WindowSize() == sim::ChannelConfig::DEFAULT_WINDOW_SIZE); } TEST_CASE("SimulationConfig setters", "[sim]") { - auto cfg_it = sim::SimulatedNetworkConfig::Builder{}.MSS(5000).Build(); + auto cfg_it = sim::ChannelConfig::Builder{}.MSS(5000).Build(); REQUIRE(cfg_it.MSS() == 5000); - REQUIRE(cfg_it.Bandwidth() == sim::SimulatedNetworkConfig::kDefaultBandwidth); + REQUIRE(cfg_it.Bandwidth() == sim::ChannelConfig::DEFAULT_BANDWIDTH); // Assume rest of properties are also defaulted correctly. } TEST_CASE("SimulationConfig validation", "[sim]") { - REQUIRE_THROWS_MATCHES( - sim::SimulatedNetworkConfig::Builder{}.Bandwidth(0).Build(), - std::invalid_argument, - Catch::Matchers::Message("bandwidth cannot be 0")); + REQUIRE_THROWS_MATCHES(sim::ChannelConfig::Builder{}.Bandwidth(0).Build(), + std::invalid_argument, + Catch::Matchers::Message("bandwidth cannot be 0")); - REQUIRE_THROWS_MATCHES(sim::SimulatedNetworkConfig::Builder{}.MSS(0).Build(), + REQUIRE_THROWS_MATCHES(sim::ChannelConfig::Builder{}.MSS(0).Build(), std::invalid_argument, Catch::Matchers::Message("MSS cannot be 0")); REQUIRE_THROWS_MATCHES( - sim::SimulatedNetworkConfig::Builder{}.PackageLoss(-0.1).Build(), + sim::ChannelConfig::Builder{}.PackageLoss(-0.1).Build(), std::invalid_argument, Catch::Matchers::Message("package loss percentage cannot be negative")); REQUIRE_THROWS_MATCHES( - sim::SimulatedNetworkConfig::Builder{}.PackageLoss(1).Build(), + sim::ChannelConfig::Builder{}.PackageLoss(1).Build(), std::invalid_argument, Catch::Matchers::Message("package loss percentage cannot exceed 100%")); REQUIRE_THROWS_MATCHES( - sim::SimulatedNetworkConfig::Builder{}.WindowSize(0).Build(), + sim::ChannelConfig::Builder{}.WindowSize(0).Build(), std::invalid_argument, Catch::Matchers::Message("TCP window size cannot be 0")); } TEST_CASE("SimulationConfig to string", "[sim]") { std::stringstream ss; - auto cfg = sim::SimulatedNetworkConfig::Builder{} + auto cfg = sim::ChannelConfig::Builder{} .Bandwidth(2) .MSS(10) .RTT(50) @@ -80,6 +133,7 @@ TEST_CASE("SimulationConfig to string", "[sim]") { ss << cfg; // clang-format off REQUIRE(ss.str() == "SimulationConfig{" + "Type: TCP, " "Bandwidth: 2 bits/s, " "RTT: 50 ms, " "MSS: 10 bytes, " @@ -87,3 +141,11 @@ TEST_CASE("SimulationConfig to string", "[sim]") { "WindowSize: 500 bytes}"); // clang-format on } + +TEST_CASE("SimulationConfig local", "[sim]") { + auto cfg = sim::ChannelConfig::Loopback(); + REQUIRE(cfg.Type() == sim::ChannelConfig::NetworkType::INSTANT); + std::stringstream ss; + ss << cfg; + REQUIRE(ss.str() == "SimulationConfig{INSTANT}"); +} diff --git a/test/scl/simulation/test_context.cc b/test/scl/simulation/test_context.cc index 9d2da4e..01d9a89 100644 --- a/test/scl/simulation/test_context.cc +++ b/test/scl/simulation/test_context.cc @@ -36,12 +36,16 @@ auto SomeEvent() { util::Time::Duration::zero()); } +auto DefaultNetworkConfig() { + return std::make_shared(); +} + } // namespace TEST_CASE("Simulation context add events", "[sim]") { - auto ctx = sim::SimulationContext::Create( + auto ctx = sim::Context::Create( 5, - sim::DefaultConfigCreator()); + DefaultNetworkConfig()); ctx->AddEvent(2, SomeEvent()); ctx->AddEvent(2, SomeEvent()); @@ -55,9 +59,9 @@ TEST_CASE("Simulation context add events", "[sim]") { } TEST_CASE("Simulation context total run time", "[sim]") { - auto ctx = sim::SimulationContext::Create( + auto ctx = sim::Context::Create( 5, - sim::DefaultConfigCreator()); + DefaultNetworkConfig()); ctx->AddEvent(0, SomeEvent()); auto t0 = ctx->Checkpoint(0); @@ -71,13 +75,15 @@ TEST_CASE("Simulation context total run time", "[sim]") { namespace { struct DummyChannelBuffer final : public sim::ChannelBuffer { - std::vector Read(std::size_t n) override { + void Read(unsigned char* data, std::size_t n) override { + (void)data; (void)n; throw std::logic_error("not supported"); } - void Write(const std::vector& data) override { + void Write(const unsigned char* data, std::size_t n) override { (void)data; + (void)n; throw std::logic_error("not supported"); } @@ -107,11 +113,10 @@ struct DummyChannelBuffer final : public sim::ChannelBuffer { namespace scl { template <> -std::shared_ptr -sim::SimulationContext::Create( +std::shared_ptr sim::Context::Create( std::size_t n, - const sim::SimulatedNetworkConfigCreator& config_creator) { - auto ctx = std::make_shared(config_creator); + std::shared_ptr config) { + auto ctx = std::make_shared(config); ctx->m_nparties = n; ctx->m_traces.resize(n); @@ -132,9 +137,8 @@ sim::SimulationContext::Create( (ctx)->Buffer(sim::ChannelId((i), (j)))) TEST_CASE("Simulation context prepare-commit-rollback", "[sim]") { - auto ctx = sim::SimulationContext::Create( - 5, - sim::DefaultConfigCreator()); + auto ctx = + sim::Context::Create(5, DefaultNetworkConfig()); ctx->Prepare(0); @@ -181,9 +185,8 @@ TEST_CASE("Simulation context prepare-commit-rollback", "[sim]") { } TEST_CASE("Simulation context invalid prepare-commit-rollback", "[sim]") { - auto ctx = sim::SimulationContext::Create( - 5, - sim::DefaultConfigCreator()); + auto ctx = + sim::Context::Create(5, DefaultNetworkConfig()); REQUIRE_THROWS_MATCHES(ctx->Commit(0), std::logic_error, @@ -214,9 +217,8 @@ auto StartEvent() { } // namespace TEST_CASE("Simulation context NextToRun simple", "[sim]") { - auto ctx = sim::SimulationContext::Create( - 3, - sim::DefaultConfigCreator()); + auto ctx = + sim::Context::Create(3, DefaultNetworkConfig()); // First party to run is always party 0 auto next = ctx->NextToRun(); @@ -249,10 +251,10 @@ TEST_CASE("Simulation context NextToRun simple", "[sim]") { } TEST_CASE("Simulation context NextToRun fails", "[sim]") { - auto ctx = sim::SimulationContext::Create( - 3, - sim::DefaultConfigCreator()); + auto ctx = + sim::Context::Create(3, DefaultNetworkConfig()); + // 0 running auto next = ctx->NextToRun(); ctx->Prepare(0); @@ -261,10 +263,13 @@ TEST_CASE("Simulation context NextToRun fails", "[sim]") { ctx->Rollback(0); + // 2 running next = ctx->NextToRun(next); - next = ctx->NextToRun(next); - REQUIRE(next.has_value()); - REQUIRE(next.value() == 2); + if (!next.has_value()) { + FAIL("no output"); + } else { + REQUIRE(next.value() == 2); + } ctx->Prepare(2); @@ -294,23 +299,22 @@ TEST_CASE("Simulation context NextToRun fails", "[sim]") { } TEST_CASE("Simulation context rollback write ops", "[sim]") { - auto ctx = sim::SimulationContext::Create( - 3, - sim::DefaultConfigCreator()); + auto ctx = + sim::Context::Create(3, DefaultNetworkConfig()); const auto ts = util::Time::Duration::zero(); // party 0 sends to party 1 ctx->Prepare(0); - ctx->RecordWrite({0, 1}, 10, ts); + ctx->AddWrite({0, 1}, 10, ts); ctx->Commit(0); // party 1 receives data from party 0, but then performs a rollback. ctx->Prepare(1); - REQUIRE(ctx->Writes({0, 1})[0].amount == 10); - ctx->Writes({0, 1})[0].amount = 0; + REQUIRE(ctx->NextWrite({0, 1}).amount == 10); + ctx->NextWrite({0, 1}).amount = 0; ctx->Rollback(1); // the change to the write op above should be undone by the rollback. - REQUIRE(ctx->Writes({0, 1})[0].amount == 10); + REQUIRE(ctx->NextWrite({0, 1}).amount == 10); } diff --git a/test/scl/simulation/test_env.cc b/test/scl/simulation/test_env.cc index 87b68cb..3019350 100644 --- a/test/scl/simulation/test_env.cc +++ b/test/scl/simulation/test_env.cc @@ -37,15 +37,19 @@ auto SomeEvent(util::Time::Duration t) { return std::make_shared(sim::Event::Type::START, t); } +auto DefaultNetworkConfig() { + return std::make_shared(); +} + } // namespace TEST_CASE("Simulation env clock", "[sim]") { using namespace std::chrono_literals; - auto ctx = sim::SimulationContext::Create( + auto ctx = sim::Context::Create( 5, - sim::DefaultConfigCreator()); - sim::SimulatedClock clock(ctx, 0); + DefaultNetworkConfig()); + sim::Clock clock(ctx, 0); ctx->AddEvent(0, SomeEvent()); ctx->UpdateCheckpoint(); @@ -71,10 +75,10 @@ TEST_CASE("Simulation env clock", "[sim]") { TEST_CASE("Simulation env clock checkpoint", "[sim]") { using namespace std::chrono_literals; - auto ctx = sim::SimulationContext::Create( + auto ctx = sim::Context::Create( 5, - sim::DefaultConfigCreator()); - sim::SimulatedClock clock(ctx, 0); + DefaultNetworkConfig()); + sim::Clock clock(ctx, 0); ctx->AddEvent(0, SomeEvent(10ms)); ctx->UpdateCheckpoint(); @@ -84,19 +88,19 @@ TEST_CASE("Simulation env clock checkpoint", "[sim]") { REQUIRE(ctx->Trace(0).back()->Timestamp() >= 10ms); sim::CheckpointEvent* e = (sim::CheckpointEvent*)ctx->Trace(0).back().get(); - REQUIRE(e->Message() == "asd"); + REQUIRE(e->Id() == "asd"); } TEST_CASE("Simulation env thread", "[sim]") { using namespace std::chrono_literals; - auto ctx = sim::SimulationContext::Create( + auto ctx = sim::Context::Create( 5, - sim::DefaultConfigCreator()); + DefaultNetworkConfig()); ctx->UpdateCheckpoint(); - sim::SimulatedThreadCtx thread(ctx, 0); - sim::SimulatedClock clock(ctx, 0); + sim::ThreadCtx thread(ctx, 0); + sim::Clock clock(ctx, 0); ctx->AddEvent(0, SomeEvent(util::Time::Duration(1000ms))); thread.Sleep(1000000); diff --git a/test/scl/simulation/test_event.cc b/test/scl/simulation/test_event.cc index 36ad65b..e35dd26 100644 --- a/test/scl/simulation/test_event.cc +++ b/test/scl/simulation/test_event.cc @@ -126,6 +126,11 @@ TEST_CASE("Simulation Event", "[sim]") { REQUIRE(ToString(&e) == "CLOSE at 0 ms [Local=2, Remote=5]"); } + SECTION("KILLED") { + sim::Event e(sim::Event::Type::KILLED, util::Time::Duration::zero()); + REQUIRE(ToString(&e) == "KILLED at 0 ms"); + } + SECTION("CHECKPOINT") { sim::CheckpointEvent e(util::Time::Duration::zero(), "asd"); REQUIRE(ToString(&e) == "CHECKPOINT at 0 ms [asd]"); diff --git a/test/scl/simulation/test_manager.cc b/test/scl/simulation/test_manager.cc new file mode 100644 index 0000000..b87dc98 --- /dev/null +++ b/test/scl/simulation/test_manager.cc @@ -0,0 +1,58 @@ +/* SCL --- Secure Computation Library + * Copyright (C) 2023 Anders Dalskov + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +#include +#include + +#include "scl/simulation/config.h" +#include "scl/simulation/context.h" +#include "scl/simulation/manager.h" +#include "scl/simulation/mem_channel_buffer.h" + +using namespace scl; + +TEST_CASE("SingleReplicationManager", "[sim]") { + sim::SingleReplicationManager m({}); + + auto p = m.Protocol(); + REQUIRE(p.empty()); + + REQUIRE_THROWS_MATCHES( + m.Protocol(), + std::logic_error, + Catch::Matchers::Message( + "Protocol called twice on SingleReplicationManager")); +} + +struct DummyManager final : public sim::Manager { + DummyManager() : sim::Manager(1) {} + std::vector> Protocol() override { + return {}; + } +}; + +TEST_CASE("Default Manager methods", "[sim]") { + DummyManager m; + + auto ctx = sim::Context::Create(1, nullptr); + REQUIRE_FALSE(m.Terminate(0, ctx->GetView())); + + // checks that the config returned is of type SimpleNetworkConfig. + auto p = std::dynamic_pointer_cast( + m.NetworkConfiguration()); + REQUIRE(p != nullptr); +} diff --git a/test/scl/simulation/test_measurement.cc b/test/scl/simulation/test_measurement.cc index 6c914c9..43f05a8 100644 --- a/test/scl/simulation/test_measurement.cc +++ b/test/scl/simulation/test_measurement.cc @@ -50,11 +50,8 @@ TEST_CASE("Measurement data", "[sim]") { dm.AddSample(7); dm.AddSample(9); - REQUIRE(dm.Min() == 2); - REQUIRE(dm.Max() == 9); - REQUIRE(dm.Mean() == 5); - REQUIRE(dm.Median() == 5); - REQUIRE(dm.StdDev() == 2); + REQUIRE(dm.Size() == 8); + REQUIRE(dm.Samples() == std::vector({2, 4, 4, 4, 5, 5, 7, 9})); } TEST_CASE("Measurement time", "[sim]") { @@ -70,19 +67,9 @@ TEST_CASE("Measurement time", "[sim]") { tm.AddSample(7ms); tm.AddSample(9ms); - REQUIRE(tm.Min() == 2ms); - REQUIRE(tm.Max() == 9ms); - REQUIRE(tm.Mean() == 5ms); - REQUIRE(tm.Median() == 5ms); - REQUIRE(tm.StdDev() == 2ms); -} - -TEST_CASE("Measurement median", "[sim]") { - sim::DataMeasurement dm; - REQUIRE(dm.Median() == 0); - - dm.AddSample(123); - REQUIRE(dm.Median() == 123); + REQUIRE(tm.Size() == 8); + REQUIRE(tm.Samples() == std::vector( + {2ms, 4ms, 4ms, 4ms, 5ms, 5ms, 7ms, 9ms})); } TEST_CASE("Measurement samples", "[sim]") { @@ -95,5 +82,5 @@ TEST_CASE("Measurement samples", "[sim]") { dm.AddSample(22); REQUIRE(dm.Size() == 2); - REQUIRE(dm.Samples() == std::vector{22, 42}); + REQUIRE(dm.Samples() == std::vector{42, 22}); } diff --git a/test/scl/simulation/test_mem_channel_buffer.cc b/test/scl/simulation/test_mem_channel_buffer.cc index 82b7a07..e0ab631 100644 --- a/test/scl/simulation/test_mem_channel_buffer.cc +++ b/test/scl/simulation/test_mem_channel_buffer.cc @@ -33,17 +33,18 @@ TEST_CASE("Simulation MemoryBackedChannelBuffer", "[sim]") { std::vector data = {1, 2, 3, 4}; - chl0->Write(data); + chl0->Write(data.data(), data.size()); REQUIRE(chl0->Size() == 0); REQUIRE(chl1->Size() == 4); - auto d = chl1->Read(2); + std::vector d(2); + chl1->Read(d.data(), 2); REQUIRE(d == std::vector{1, 2}); REQUIRE(chl1->Size() == 2); - auto e = chl1->Read(2); - REQUIRE(e == std::vector{3, 4}); + chl1->Read(d.data(), 2); + REQUIRE(d == std::vector{3, 4}); REQUIRE(chl1->Size() == 0); } @@ -60,18 +61,19 @@ TEST_CASE("Simulation MemoryBackedChannelBuffer rollback", "[sim]") { local->Prepare(); - local->Write(data); + local->Write(data.data(), data.size()); REQUIRE(remote->Size() == 4); local->Rollback(); REQUIRE(remote->Size() == 0); - remote->Write(data); + remote->Write(data.data(), data.size()); local->Prepare(); REQUIRE(local->Size() == 4); - local->Read(2); + std::vector d(2); + local->Read(d.data(), 2); REQUIRE(local->Size() == 2); local->Rollback(); @@ -85,23 +87,24 @@ TEST_CASE("Simulation MemoryBackedChannelBuffer rollback", "[sim]") { local->Prepare(); - local->Write(data); + local->Write(data.data(), data.size()); local->Commit(); local->Prepare(); - local->Write(data); + local->Write(data.data(), data.size()); REQUIRE(remote->Size() == 8); local->Rollback(); REQUIRE(remote->Size() == 4); - remote->Write(data); + remote->Write(data.data(), data.size()); local->Prepare(); REQUIRE(local->Size() == 4); - local->Read(2); + std::vector d(2); + local->Read(d.data(), 2); REQUIRE(local->Size() == 2); local->Rollback(); @@ -112,16 +115,17 @@ TEST_CASE("Simulation MemoryBackedChannelBuffer rollback", "[sim]") { auto lo = sim::MemoryBackedChannelBuffer::CreateLoopback(); lo->Prepare(); - lo->Write(data); + lo->Write(data.data(), data.size()); REQUIRE(lo->Size() == 4); lo->Commit(); lo->Prepare(); - lo->Write(data); + lo->Write(data.data(), data.size()); REQUIRE(lo->Size() == 8); - auto d3 = lo->Read(3); + std::vector d3(3); + lo->Read(d3.data(), 3); REQUIRE(d3 == std::vector{1, 2, 3}); REQUIRE(lo->Size() == 5); diff --git a/test/scl/simulation/test_result.cc b/test/scl/simulation/test_result.cc index f018f12..8ee16ae 100644 --- a/test/scl/simulation/test_result.cc +++ b/test/scl/simulation/test_result.cc @@ -113,11 +113,11 @@ TEST_CASE("Simulation result sent recv", "[sim]") { Stop()); auto r = sim::Result::Create(trace); - REQUIRE(r[0].TransferAmounts(2).sent.Mean() == 0); - REQUIRE(r[0].TransferAmounts(2).recv.Mean() == 444); + REQUIRE(r[0].TransferAmounts(2).sent.Samples()[0] == 0); + REQUIRE(r[0].TransferAmounts(2).recv.Samples()[0] == 444); - REQUIRE(r[0].TransferAmounts(1).sent.Mean() == 123 + 22); - REQUIRE(r[0].TransferAmounts(1, "bar").sent.Mean() == 22); + REQUIRE(r[0].TransferAmounts(1).sent.Samples()[0] == 123 + 22); + REQUIRE(r[0].TransferAmounts(1, "bar").sent.Samples()[0] == 22); std::vector expected = {1, 2, 3}; REQUIRE_THAT(r[0].Interactions(), Catch::Matchers::UnorderedEquals(expected)); @@ -126,16 +126,43 @@ TEST_CASE("Simulation result sent recv", "[sim]") { Catch::Matchers::UnorderedEquals(expected_bar)); } +namespace { + +std::shared_ptr Checkpoint(const std::string& message) { + return std::make_shared(util::Time::Duration::zero(), + message); +} + +} // namespace + +TEST_CASE("Simulation result with checkpoint", "[sim]") { + TraceT trace = CREATE_TRACE(Start(), + BeginSegment(), + Checkpoint("x"), + EndSegment(), + BeginSegment(), + Checkpoint("x"), + Checkpoint("y"), + EndSegment(), + Stop()); + + auto r = sim::Result::Create(trace); + REQUIRE(r[0].Checkpoint("x").Size() == 1); + REQUIRE(r[0].Checkpoint("y").Size() == 1); +} + TEST_CASE("Simulation result write", "[sim]") { // TODO: This doesn't really test anything besides that Write is // stable(-ish). Ideally, the test should check that the result is consistent // with the content of a file on disk, but that likely requires that Write is - // deterministic, which is not the case base writes a ton of unordered maps. + // deterministic, which is not the case because writes a ton of unordered + // maps. TraceT trace = CREATE_TRACE(Start(), BeginSegment(), Send(0, 1, 123), Recv(0, 2, 444), + Checkpoint("x"), EndSegment(), BeginSegment("bar"), Send(0, 3, 42), @@ -151,3 +178,12 @@ TEST_CASE("Simulation result write", "[sim]") { r[0].Write(ss1); REQUIRE(ss0.str() == ss1.str()); } + +TEST_CASE("Simulation result write trace invalid replication", "[sim]") { + TraceT trace = CREATE_TRACE(Start(), Stop()); + auto r = sim::Result::Create(trace); + + REQUIRE_THROWS_MATCHES(r[0].WriteTrace(std::cout, 42), + std::invalid_argument, + Catch::Matchers::Message("invalid replication")); +} diff --git a/test/scl/simulation/test_simulator.cc b/test/scl/simulation/test_simulator.cc index 8583de6..13f2607 100644 --- a/test/scl/simulation/test_simulator.cc +++ b/test/scl/simulation/test_simulator.cc @@ -15,594 +15,276 @@ * along with this program. If not, see . */ +#include #include #include +#include #include #include #include +#include #include +#include #include "../protocol/beaver.h" #include "scl/math/fp.h" #include "scl/protocol/base.h" #include "scl/protocol/env.h" #include "scl/simulation/config.h" +#include "scl/simulation/manager.h" #include "scl/simulation/result.h" #include "scl/simulation/simulator.h" #include "scl/ss/additive.h" #include "scl/util/prg.h" using namespace scl; +using namespace std::chrono_literals; using FF = math::Fp<61>; -using namespace std::chrono_literals; +using Parties = std::vector>; namespace { -template -void ApproxDuration(util::Time::Duration d, T v, T b) { - if (v > d) { - REQUIRE(v - d <= b); - } else { - REQUIRE(d - v <= b); - } -} - -std::size_t KB(std::size_t bytes) { - return 1000 * bytes; +template +std::unique_ptr CreateParty(Ts&&... init_args) { + return std::make_unique(std::forward(init_args)...); } -std::size_t MB(std::size_t bytes) { - return 1000 * KB(bytes); +auto RecvTimeDefaultConf(std::size_t n) { + static const auto dft = sim::ChannelConfig::Default(); + return sim::ComputeRecvTime(dft, n); } } // namespace -TEST_CASE("ComputeRecvTime default config", "[sim]") { - // https://wintelguy.com/wanperf.pl - // parameters: - // Link bandwidth (Mbit/s): 1 - // RTT (millisecond): 100 - // Packet loss (%): 0 - // MTU (Byte): 1500 - // L1/L2 frame overhead (Byte): 0 <-- not accounted for in scl - // TCP/IP (v4) header overhead (Byte): 40 - // TCP window (RWND) size (Byte): 65536 - // File size (MByte): 1 - - const auto cfg = sim::SimulatedNetworkConfig::Default(); - const auto tenMB = MB(10); - const auto t = sim::ComputeRecvTime(cfg, tenMB); - ApproxDuration(t, 82s, 1s); -} - -TEST_CASE("ComputeRecvTime lossy", "[sim]") { - const auto cfg = - sim::SimulatedNetworkConfig::Builder().PackageLoss(0.001).Build(); - const auto tenMB = MB(10); - const auto t = sim::ComputeRecvTime(cfg, tenMB); - ApproxDuration(t, 82s, 1s); -} - -/** - * @brief Protocol with many rounds. - * - * This protocol runs for 101 steps. Each of the first 100 steps send a single - * int while the last step receives all the data sent. Each party will send some - * data to the next party. - * - * The output of the protocol is a boolean indicating if the received data - * matched the data sent. - */ -struct LotsOfDataProtocol { - struct Two final : proto::Protocol { +struct SimpleSendRecvProtocol { + struct Sender final : public proto::Protocol { std::unique_ptr Run(proto::Env& env) override { - std::vector data(100); - auto& network = env.network; - const auto id = network.MyId(); - const auto rid = id == 0 ? network.Size() - 1 : id - 1; - for (std::size_t i = 0; i < 100; ++i) { - auto p = network.Previous()->Recv(); - if (!p.has_value()) { - output = false; - } else { - std::size_t rid_c = p.value().Read(); - output &= rid_c == rid + i; - } - } - network.Previous()->Close(); + net::Packet p; + p << (std::size_t)123; + p << (int)-100; + env.network.Other()->Send(p); + env.network.Close(); return nullptr; - }; - - std::any Output() const override { - return output; } - - bool output = true; }; - struct One final : proto::Protocol { - One(int counter = 0) : counter(counter){}; + struct Receiver final : public proto::Protocol { std::unique_ptr Run(proto::Env& env) override { - auto& network = env.network; - const auto id = network.MyId(); - net::Packet p; - p << (int)(id + counter); - network.Next()->Send(p); - if (counter > 100) { - return std::make_unique(); - } - return std::make_unique(counter + 1); - }; - - int counter; - }; -}; - -TEST_CASE("Simulation many", "[sim]") { - const auto n_parties = 10; - - std::vector outputs(n_parties, false); - - const auto output_cb = [&outputs](std::size_t id, const std::any& output) { - REQUIRE(output.has_value()); - outputs[id] = std::any_cast(output); - }; - - std::vector> parties; - for (std::size_t i = 0; i < n_parties; ++i) { - parties.emplace_back(std::make_unique()); - } - - const auto r = - sim::Simulate(std::move(parties), sim::DefaultConfigCreator(), output_cb); + auto p = env.network.Other()->Recv(true); - for (std::size_t i = 0; i < n_parties; ++i) { - REQUIRE(outputs[i]); - } -} - -/** - * @brief Simple protocol where one party sends a boolean to another party. - */ -struct SendRecvProtocol { - struct Sender final : proto::Protocol { - std::unique_ptr Run(proto::Env& env) override { - net::Packet p; - p << true; - env.network.Other()->Send(p); - return nullptr; - } - }; + if (!p.has_value()) { + throw std::runtime_error("expected data"); + } + auto& v = p.value(); - struct Receiver final : proto::Protocol { - Receiver(const std::function& cb) : cb(cb){}; + is_correct = v.Read() == 123; + is_correct &= v.Read() == -100; - std::unique_ptr Run(proto::Env& env) override { - cb(); - auto p = env.network.Other()->Recv(); - output = p.has_value() && p.value().Read(); + env.network.Close(); return nullptr; - } + }; std::any Output() const override { - return output; + return is_correct; } - bool output; - std::function cb; + bool is_correct = false; }; }; -TEST_CASE("Simulation result trace", "[sim]") { - std::vector> parties; - parties.emplace_back(std::make_unique([]() {})); - parties.emplace_back(std::make_unique()); - - const auto r = sim::Simulate(std::move(parties), sim::DefaultConfigCreator()); - - std::stringstream ss; - r[0].WriteTrace(ss, 0); +namespace { - std::string line; +void VerifySendRecvProtocolResult(const std::vector& result) { + REQUIRE(result.size() == 2); - std::getline(ss, line); - REQUIRE_THAT(line, Catch::Matchers::StartsWith("START")); + const auto& r0 = result[0]; + REQUIRE(r0.SegmentNames().size() == 1); + REQUIRE(r0.SegmentNames()[0] == proto::Protocol::DEFAULT_NAME); - std::getline(ss, line); - REQUIRE_THAT(line, Catch::Matchers::StartsWith("SEGMENT_BEGIN")); + const auto et0 = r0.ExecutionTime(); + REQUIRE(et0.Size() == 1); + REQUIRE(et0.Samples()[0] < 1ms); - std::getline(ss, line); - REQUIRE_THAT(line, Catch::Matchers::StartsWith("PACKET_RECV")); - - std::getline(ss, line); - REQUIRE_THAT(line, Catch::Matchers::StartsWith("OUTPUT")); + const auto et1 = result[1].ExecutionTime(); + REQUIRE(et1.Size() == 1); + const auto bytes_recv = + sizeof(int) + sizeof(std::size_t) + sizeof(net::Packet::SizeType); + REQUIRE(et1.Samples()[0] < RecvTimeDefaultConf(bytes_recv) + 1ms); +} - std::getline(ss, line); - REQUIRE_THAT(line, Catch::Matchers::StartsWith("SEGMENT_END")); +} // namespace - std::getline(ss, line); - REQUIRE_THAT(line, Catch::Matchers::StartsWith("STOP")); +TEST_CASE("Simulate SimpleSendRecvProtocol", "[sim]") { + Parties p; + p.emplace_back(CreateParty()); + p.emplace_back(CreateParty()); - REQUIRE_THROWS_MATCHES(r[0].WriteTrace(ss, 42), - std::invalid_argument, - Catch::Matchers::Message("invalid iteration")); + const auto result = sim::Simulate(std::move(p)); + VerifySendRecvProtocolResult(result); +} - ss.str(""); - r[0].WriteTrace(ss, 0, proto::Protocol::kDefaultName); +TEST_CASE("Simulate SimpleSendRecvProtocol reverse", "[sim]") { + Parties p; + p.emplace_back(CreateParty()); + p.emplace_back(CreateParty()); - std::getline(ss, line); - REQUIRE_THAT(line, Catch::Matchers::StartsWith("SEGMENT_BEGIN")); + const auto result = sim::Simulate(std::move(p)); + VerifySendRecvProtocolResult({result[1], result[0]}); +} - std::getline(ss, line); - REQUIRE_THAT(line, Catch::Matchers::StartsWith("PACKET_RECV")); +namespace { - std::getline(ss, line); - REQUIRE_THAT(line, Catch::Matchers::StartsWith("OUTPUT")); +void VerifyType(std::shared_ptr event, sim::Event::Type type) { + REQUIRE(event->EventType() == type); +} +void VerifyTypeString(std::stringstream& ss, + const std::string& event_type_str) { + std::string line; std::getline(ss, line); - REQUIRE_THAT(line, Catch::Matchers::StartsWith("SEGMENT_END")); + REQUIRE_THAT(line, Catch::Matchers::StartsWith(event_type_str)); } -TEST_CASE("Simulation odd/even iterations", "[sim]") { - sim::DataMeasurement m_even; - sim::DataMeasurement m_odd; - - // Cannot use SECTION here as m_even somehow gets overwritten with garbage... - - { - const auto creator = []() { - std::vector> parties; - parties.emplace_back(std::make_unique()); - parties.emplace_back( - std::make_unique([]() {})); - return parties; - }; +} // namespace - const auto r = sim::Simulate(creator, sim::DefaultConfigCreator(), 2); - m_even = r[0].TransferAmounts().recv; +TEST_CASE("Simulate SimpleSendRecvProtocol trace", "[sim]") { + Parties p; + p.emplace_back(CreateParty()); + p.emplace_back(CreateParty()); + + const auto result = sim::Simulate(std::move(p)); + + SECTION("Sender") { + const auto& sender_trace = result[0].Trace(0); + + VerifyType(sender_trace[0], sim::Event::Type::START); + VerifyType(sender_trace[1], sim::Event::Type::SEGMENT_BEGIN); + VerifyType(sender_trace[2], sim::Event::Type::PACKET_SEND); + VerifyType(sender_trace[3], sim::Event::Type::CLOSE); // to self + VerifyType(sender_trace[4], sim::Event::Type::CLOSE); // to other + VerifyType(sender_trace[5], sim::Event::Type::SEGMENT_END); + VerifyType(sender_trace[6], sim::Event::Type::STOP); + + std::stringstream ss; + result[0].WriteTrace(ss, 0); + VerifyTypeString(ss, "START"); + VerifyTypeString(ss, "SEGMENT_BEGIN"); + VerifyTypeString(ss, "PACKET_SEND"); + VerifyTypeString(ss, "CLOSE"); + VerifyTypeString(ss, "CLOSE"); + VerifyTypeString(ss, "SEGMENT_END"); + VerifyTypeString(ss, "STOP"); } - { - const auto creator = []() { - std::vector> parties; - parties.emplace_back(std::make_unique()); - parties.emplace_back( - std::make_unique([]() {})); - return parties; - }; - - const auto r = sim::Simulate(creator, sim::DefaultConfigCreator(), 3); - m_odd = r[0].TransferAmounts().recv; + SECTION("Receiver") { + const auto& receiver_trace = result[1].Trace(0); + + VerifyType(receiver_trace[0], sim::Event::Type::START); + VerifyType(receiver_trace[1], sim::Event::Type::SEGMENT_BEGIN); + VerifyType(receiver_trace[2], sim::Event::Type::PACKET_RECV); + VerifyType(receiver_trace[3], sim::Event::Type::CLOSE); // to self + VerifyType(receiver_trace[4], sim::Event::Type::CLOSE); // to other + VerifyType(receiver_trace[5], sim::Event::Type::OUTPUT); + VerifyType(receiver_trace[6], sim::Event::Type::SEGMENT_END); + VerifyType(receiver_trace[7], sim::Event::Type::STOP); + + std::stringstream ss; + result[1].WriteTrace(ss, 0); + + VerifyTypeString(ss, "START"); + VerifyTypeString(ss, "SEGMENT_BEGIN"); + VerifyTypeString(ss, "PACKET_RECV"); + VerifyTypeString(ss, "CLOSE"); + VerifyTypeString(ss, "CLOSE"); + VerifyTypeString(ss, "OUTPUT"); + VerifyTypeString(ss, "SEGMENT_END"); + VerifyTypeString(ss, "STOP"); } - - REQUIRE(m_even.Mean() == m_odd.Mean()); - REQUIRE(m_even.Median() == m_odd.Median()); - REQUIRE(m_even.Min() == m_odd.Min()); - REQUIRE(m_even.Max() == m_odd.Max()); } -TEST_CASE("Simulation receive out-of-order", "[sim]") { - SECTION("Receive before send") { - int called = 0; - const auto cb = [&called]() { called++; }; - - std::vector> parties; - parties.emplace_back(std::make_unique(cb)); - parties.emplace_back(std::make_unique()); - - const auto r = - sim::Simulate(std::move(parties), sim::DefaultConfigCreator()); - - // receive bool (1 byte) + packet size (4 bytes). - REQUIRE(r[0].TransferAmounts().recv.Max() == 5.0); - ApproxDuration(r[0].ExecutionTime().Max(), 100ms, 1ms); - ApproxDuration(r[1].ExecutionTime().Max(), 1ms, 1ms); - - // Ensure that receiver was called twice. - REQUIRE(called == 2); - } - - SECTION("Send before receive") { - int called = 0; - const auto cb = [&called]() { called++; }; - - std::vector> parties_; - parties_.emplace_back(std::make_unique()); - parties_.emplace_back(std::make_unique(cb)); +TEST_CASE("Simulate null protocol", "[sim]") { + Parties p; + p.emplace_back(nullptr); - const auto r_ = - sim::Simulate(std::move(parties_), sim::DefaultConfigCreator()); + const auto result = sim::Simulate(std::move(p)); - // sent bool (1 byte) + packet size (4 bytes). - REQUIRE(r_[0].TransferAmounts().sent.Max() == 5.0); - ApproxDuration(r_[1].ExecutionTime().Max(), 100ms, 1ms); - ApproxDuration(r_[0].ExecutionTime().Max(), 1ms, 1ms); - - REQUIRE(called == 1); - } + REQUIRE(result.size() == 1); + const auto trace = result[0].Trace(0); + REQUIRE(trace[0]->EventType() == sim::Event::Type::START); + REQUIRE(trace[1]->EventType() == sim::Event::Type::STOP); } -/** - * @brief Two party protocol that uses HasData. - * - * This protocol captures both failure cases for simulating a HasData call. Both - * failure cases arise - */ -struct HasDataProtocol { - struct Bob; - struct Alice; - - struct Alice final : public proto::Protocol { - Alice(bool sleep, bool exit_early, bool send) - : sleep(sleep), exit_early(exit_early), send(send){}; - +struct PingPongProtocol { + struct Ping final : public proto::Protocol { std::unique_ptr Run(proto::Env& env) override { - if (sleep) { - env.thread_ctx->Sleep(50); - } - if (send) { - env.network.Other()->Send(42); - } - if (exit_early) { - return nullptr; - } - return std::make_unique(); + unsigned char data[] = {'a', 'b', 'c'}; + env.network.Other()->Send(data, 3); + env.thread_ctx->Sleep(1000); + return std::make_unique(); + } + std::string Name() const override { + return "Ping"; } - - struct Dummy final : public proto::Protocol { - std::unique_ptr Run(proto::Env& env) override { - (void)env; - return nullptr; - } - }; - - bool sleep; - bool exit_early; - bool send; }; - struct Bob final : public proto::Protocol { - Bob(bool sleep) : sleep(sleep){}; - + struct Pong final : public proto::Protocol { std::unique_ptr Run(proto::Env& env) override { - if (sleep) { - env.thread_ctx->Sleep(50); - } - output = env.network.Other()->HasData(); - return nullptr; + unsigned char data[3] = {0}; + env.network.Other()->Recv(data, 3); + bool good = data[0] == 'a' && data[1] == 'b' && data[2] == 'c'; + env.clock->Checkpoint(good ? "yay" : "boo"); + return std::make_unique(); } - - std::any Output() const override { - return output; + std::string Name() const override { + return "Pong"; } - - bool output; - bool sleep; - }; -}; - -namespace { - -bool TestHasData(bool alice_sleep, - bool alice_exit_early, - bool alice_send, - bool bob_sleep, - bool bob_before_alice = false) { - using Alice = HasDataProtocol::Alice; - using Bob = HasDataProtocol::Bob; - - std::optional output; - const auto cb = [&output](std::size_t id, const std::any& o) { - (void)id; - output = std::any_cast(o); - }; - - std::vector> parties; - if (bob_before_alice) { - parties.emplace_back(std::make_unique(bob_sleep)); - parties.emplace_back( - std::make_unique(alice_sleep, alice_exit_early, alice_send)); - } else { - parties.emplace_back( - std::make_unique(alice_sleep, alice_exit_early, alice_send)); - parties.emplace_back(std::make_unique(bob_sleep)); - } - - const auto r = - sim::Simulate(std::move(parties), sim::DefaultConfigCreator(), cb); - - if (output.has_value()) { - return output.value(); - } - FAIL("output did not have a value"); - return false; -} - -} // namespace - -TEST_CASE("Simulation HasData", "[sim]") { - SECTION("Alice never sends data") { - const auto result = TestHasData(true, false, false, false); - REQUIRE_FALSE(result); - } - - SECTION("Alice sends data after Bob") { - const auto result = TestHasData(true, false, true, false); - REQUIRE_FALSE(result); - } - - SECTION("Alice sends data before Bob") { - const auto result = TestHasData(false, false, true, true); - REQUIRE(result); - } - - SECTION("Bob before Alice") { - const auto result = TestHasData(false, true, true, true, true); - REQUIRE(result); - } -} - -TEST_CASE("Simulation Beaver", "[sim]") { - auto creator = []() { - std::vector> parties; - auto prg = util::PRG::Create(); - auto xs = ss::AdditiveShare(FF(42), 2, prg); - auto ys = ss::AdditiveShare(FF(11), 2, prg); - auto ts = test::RandomTriple(prg); - - parties.emplace_back(test::BeaverMul::Create(xs[0], ys[0], ts[0])); - parties.emplace_back(test::BeaverMul::Create(xs[1], ys[1], ts[1])); - - return parties; }; - - const auto result = sim::Simulate(creator, sim::DefaultConfigCreator(), 10); - - SECTION("segment names") { - const auto segment_names = result[0].SegmentNames(); - std::vector expected_names = {"init", "finalize"}; - REQUIRE_THAT(segment_names, Catch::Matchers::Contains(expected_names)); - } - - SECTION("running time") { - ApproxDuration(result[0].ExecutionTime().Mean(), 113ms, 1ms); - ApproxDuration(result[1].ExecutionTime().Mean(), 113ms, 1ms); - - ApproxDuration(result[0].ExecutionTime("init").Mean(), 1ms, 1ms); - ApproxDuration(result[1].ExecutionTime("finalize").Mean(), 113ms, 1ms); - } - - SECTION("transfer") { - // Each party sends 8 bytes (vec lenghts) + 100 field elements of 8 bytes - // each. - REQUIRE(result[0].TransferAmounts().sent.Mean() == 8 + 4 * 100 * 8); - REQUIRE(result[1].TransferAmounts().sent.Mean() == 8 + 4 * 100 * 8); - - REQUIRE(result[0].TransferAmounts().recv.Mean() == 8 + 4 * 100 * 8); - REQUIRE(result[1].TransferAmounts().recv.Mean() == 8 + 4 * 100 * 8); - - // The simple beaver protocol sends is deterministic wrt. data amounts - REQUIRE(result[0].TransferAmounts().recv.Min() == - result[0].TransferAmounts().recv.Max()); - - REQUIRE(result[1].TransferAmounts().recv.Min() == - result[1].TransferAmounts().recv.Max()); - } -} - -struct SinglePartyProtocol final : public proto::Protocol { - SinglePartyProtocol(bool send = true) : send(send){}; - - std::unique_ptr Run(proto::Env& env) override { - if (send) { - std::vector data(100, 42); - net::Packet p; - p << data; - env.network.Party(0)->Send(p); - return std::make_unique(false); - } - auto p = env.network.Party(0)->Recv(); - if (p.has_value()) { - output = p.value().Read>(); - } - return nullptr; - } - - bool send; - std::vector output; }; -TEST_CASE("Simulation one party", "[sim]") { - std::vector> party; - party.push_back(std::make_unique()); +struct PingPongManager final : public sim::Manager { + PingPongManager(std::size_t replications) : sim::Manager(replications) {} - auto r = sim::Simulate(std::move(party), sim::DefaultConfigCreator()); - // The default configuration should ensure that communication locally - // happens almost instantly - ApproxDuration(r[0].ExecutionTime().Max(), 1ms, 1ms); -} - -struct ChunkedRecvProtocol final : public proto::Protocol { - std::unique_ptr Run(proto::Env& env) override { - if (env.network.MyId() == 0) { - unsigned char buf[5] = {4, 5, 6, 7, 8}; - env.network.Other()->Send(buf, 3); - env.network.Other()->Send(buf + 3, 2); - } else { - unsigned char buf[5] = {0}; - env.network.Other()->Recv(buf, 5); - bool good = true; - for (std::size_t i = 0; i < 5; ++i) { - good &= (4 + i) == buf[i]; - } - m_output = good; - } - return nullptr; + std::vector> Protocol() override { + Parties p; + p.emplace_back(std::make_unique()); + p.emplace_back(std::make_unique()); + return p; } - std::any Output() const override { - return m_output; + bool Terminate(std::size_t party_id, + const sim::Context::View& view) override { + const auto latest_time = view.Trace(party_id).back()->Timestamp(); + return latest_time > 10s; } - - bool m_output; }; -TEST_CASE("Simulation chunked receive", "[sim]") { - bool correct = false; - auto cb = [&correct](auto id, std::any output) { - if (id == 1) { - correct = std::any_cast(output); - } - }; +TEST_CASE("Simulate PingPongProtocol", "[sim]") { + auto m = std::make_unique(1); + const auto result = sim::Simulate(std::move(m)); - std::vector> parties; - parties.emplace_back(std::make_unique()); - parties.emplace_back(std::make_unique()); - sim::Simulate(std::move(parties), sim::DefaultConfigCreator(), cb); + const auto last_event_p0 = result[0].Trace(0).back(); + VerifyType(last_event_p0, sim::Event::Type::KILLED); + REQUIRE(last_event_p0->Timestamp() >= 10000ms); - REQUIRE(correct); + const auto last_event_p1 = result[1].Trace(0).back(); + VerifyType(last_event_p1, sim::Event::Type::KILLED); + REQUIRE(last_event_p1->Timestamp() >= 10000ms); } -struct NonBlockRecvProtocol final : public proto::Protocol { - NonBlockRecvProtocol(bool sleep = false) : sleep(sleep) {} +TEST_CASE("Simulate PingPongProtocol trace", "[sim]") { + auto m = std::make_unique(1); + const auto result = sim::Simulate(std::move(m)); - std::unique_ptr Run(proto::Env& env) override { - if (env.network.MyId() == 0) { - net::Packet p; - p << 1 << 2 << 3; - env.thread_ctx->Sleep(10); - env.network.Party(1)->Send(p); - } else { - if (sleep) { - env.thread_ctx->Sleep(20); - } - - auto p = env.network.Party(0)->Recv(false); - } - return nullptr; + std::stringstream ss; + std::string line; + result[0].WriteTrace(ss, 0, "Ping"); + + // ping/pong runs for 10 iterations + for (std::size_t i = 0; i < 10; ++i) { + VerifyTypeString(ss, "SEGMENT_BEGIN"); + VerifyTypeString(ss, "SEND"); + VerifyTypeString(ss, "SLEEP"); + VerifyTypeString(ss, "SEGMENT_END"); } - - bool sleep; -}; - -TEST_CASE("Simulation non-block recv", "[sim]") { - std::vector> p_no_data; - p_no_data.emplace_back(std::make_unique()); - p_no_data.emplace_back(std::make_unique(false)); - - const auto r0 = - sim::Simulate(std::move(p_no_data), sim::DefaultConfigCreator()); - - // execution is instant because no data is available. - ApproxDuration(r0[1].ExecutionTime().Max(), 1ms, 1ms); - - std::vector> p_data; - p_data.emplace_back(std::make_unique()); - p_data.emplace_back(std::make_unique(true)); - - const auto r1 = sim::Simulate(std::move(p_data), sim::DefaultConfigCreator()); - - // execution has to wait for the packet, because data was available. - ApproxDuration(r1[1].ExecutionTime().Max(), 110ms, 10ms); } diff --git a/test/scl/ss/test_feldman.cc b/test/scl/ss/test_feldman.cc index ecf6c25..d8cab26 100644 --- a/test/scl/ss/test_feldman.cc +++ b/test/scl/ss/test_feldman.cc @@ -28,7 +28,7 @@ using namespace scl; TEST_CASE("Feldman", "[ss]") { using EC = math::EC; - using FF = EC::Order; + using FF = EC::ScalarField; auto prg = util::PRG::Create(); std::size_t t = 4; diff --git a/test/scl/ss/test_shamir.cc b/test/scl/ss/test_shamir.cc index f089702..11b09de 100644 --- a/test/scl/ss/test_shamir.cc +++ b/test/scl/ss/test_shamir.cc @@ -21,6 +21,7 @@ #include "../gf7.h" #include "scl/math/fp.h" #include "scl/math/lagrange.h" +#include "scl/math/poly.h" #include "scl/math/vec.h" #include "scl/ss/shamir.h" #include "scl/util/prg.h" @@ -77,6 +78,36 @@ TEST_CASE("Shamir reconstruct detect", "[ss]") { Catch::Matchers::Message("error detected during recovery")); } +namespace { + +math::Vec ShareWithDifferentAlphas(util::PRG& prg, + std::size_t t, + std::size_t n) { + auto c = math::Vec::Random(t + 1, prg); + c[0] = FF(123); + const auto p = math::Polynomial::Create(c); + + std::vector shares; + shares.reserve(n); + for (std::size_t i = 0; i < n; ++i) { + shares.emplace_back(p.Evaluate(FF{(int)i + 42})); + } + return math::Vec(shares); +} + +} // namespace + +TEST_CASE("Shamir reconstruct different x and alphas", "[ss]") { + auto prg = util::PRG::Create("shamir detect2"); + + const auto shares = ShareWithDifferentAlphas(prg, 3, 7); + const auto alphas = math::Vec::Range(42, 50); + + REQUIRE(ss::ShamirRecoverD(shares, alphas, FF(0)) == FF(123)); + + REQUIRE(ss::ShamirRecoverD(shares, alphas, alphas[0]) == shares[0]); +} + TEST_CASE("Shamir reconstruct correct", "[sim]") { auto prg = util::PRG::Create("shamir correct"); auto shares = ss::ShamirShare(FF(123), 2, 7, prg); @@ -95,6 +126,21 @@ TEST_CASE("Shamir reconstruct correct", "[sim]") { Catch::Matchers::Message("could not correct shares")); } +TEST_CASE("Shamir reconstruct correct different alphas", "[ss]") { + auto prg = util::PRG::Create("shamir correct2"); + + auto shares = ShareWithDifferentAlphas(prg, 2, 7); + const auto alphas = math::Vec::Range(42, 50); + + REQUIRE(ss::ShamirRecoverC(shares, alphas).f.ConstantTerm() == FF(123)); + + shares[4] = FF(5555); + + const auto r = ss::ShamirRecoverC(shares, alphas); + REQUIRE(r.f.ConstantTerm() == FF(123)); + REQUIRE(r.err.Evaluate(alphas[4]) == FF(0)); +} + TEST_CASE("BerlekampWelch", "[ss][math]") { // https://en.wikipedia.org/wiki/Berlekamp%E2%80%93Welch_algorithm#Example diff --git a/test/scl/util/test_merkle.cc b/test/scl/util/test_merkle.cc new file mode 100644 index 0000000..9fe745a --- /dev/null +++ b/test/scl/util/test_merkle.cc @@ -0,0 +1,88 @@ +/* SCL --- Secure Computation Library + * Copyright (C) 2023 Anders Dalskov + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +#include + +#include "scl/util/hash.h" +#include "scl/util/merkle.h" + +using namespace scl; + +using Mrkl = util::MerkleTree, std::string_view>; + +namespace { + +util::Hash<256>::DigestType Hash(std::string_view thing) { + return util::Hash<256>{}.Update(thing).Finalize(); +} + +util::Hash<256>::DigestType Hash(util::Hash<256>::DigestType a, + util::Hash<256>::DigestType b) { + return util::Hash<256>{}.Update(a).Update(b).Finalize(); +} + +} // namespace + +TEST_CASE("Merkle hash", "[misc]") { + auto h_abcd = Hash(Hash(Hash("a"), Hash("b")), Hash(Hash("c"), Hash("d"))); + auto m_abcd = Mrkl::Hash({"a", "b", "c", "d"}); + REQUIRE(h_abcd == m_abcd); + + auto h_xyvu = Hash(Hash(Hash("x"), Hash("y")), Hash(Hash("v"), Hash("u"))); + auto h_abcdxyvu = Hash(h_abcd, h_xyvu); + + auto m_abcdxyvu = Mrkl::Hash({"a", "b", "c", "d", "x", "y", "v", "u"}); + REQUIRE(h_abcdxyvu == m_abcdxyvu); +} + +TEST_CASE("Merkle hash odd size input", "[misc]") { + util::Hash<256>::DigestType z_digest; + z_digest.fill(0); + auto h_abc = Hash(Hash(Hash("a"), Hash("b")), Hash(Hash("c"), Hash("c"))); + auto m_abc = Mrkl::Hash({"a", "b", "c"}); + + REQUIRE(h_abc == m_abc); +} + +TEST_CASE("Merkle prove", "[misc]") { + std::vector data = {"a", "b", "c", "d", "e"}; + auto root = Mrkl::Hash(data); + + auto h_ab = Hash(Hash("a"), Hash("b")); + auto h_cd = Hash(Hash("c"), Hash("d")); + auto h_ee = Hash(Hash("e"), Hash("e")); + auto h_abcd = Hash(h_ab, h_cd); + auto h_eeee = Hash(h_ee, h_ee); + + REQUIRE(root == Hash(h_abcd, h_eeee)); + + auto proof = Mrkl::Prove(data, 3); + + // path = [H_c, H_ab, H_eeee] + // direction = [left, left, right] (true, true, false) + + REQUIRE(proof.direction.size() == 3); + REQUIRE(proof.path.size() == 3); + + REQUIRE(proof.direction == std::vector{true, true, false}); + + REQUIRE(proof.path[0] == Hash("c")); + REQUIRE(proof.path[1] == h_ab); + REQUIRE(proof.path[2] == h_eeee); + + REQUIRE(Mrkl::Verify("d", root, proof)); +} diff --git a/test/scl/util/test_sha3.cc b/test/scl/util/test_sha3.cc index 6bd9791..81858ca 100644 --- a/test/scl/util/test_sha3.cc +++ b/test/scl/util/test_sha3.cc @@ -17,13 +17,14 @@ #include +#include "scl/math/fp.h" #include "scl/util/digest.h" #include "scl/util/hash.h" using namespace scl; TEST_CASE("Sha3 empty hash", "[misc]") { - static const util::Digest<256>::Type SHA3_256_empty = { + static const util::Digest<256> SHA3_256_empty = { 0xa7, 0xff, 0xc6, 0xf8, 0xbf, 0x1e, 0xd7, 0x66, 0x51, 0xc1, 0x47, 0x56, 0xa0, 0x61, 0xd6, 0x62, 0xf5, 0x80, 0xff, 0x4d, 0xe4, 0x3b, 0x49, 0xfa, 0x82, 0xd8, 0x0a, 0x4b, 0x80, 0xf8, 0x43, 0x4a}; @@ -34,7 +35,7 @@ TEST_CASE("Sha3 empty hash", "[misc]") { } TEST_CASE("Sha3 abc hash", "[misc]") { - static const util::Digest<256>::Type SHA3_256_abc = { + static const util::Digest<256> SHA3_256_abc = { 0x3a, 0x98, 0x5d, 0xa7, 0x4f, 0xe2, 0x25, 0xb2, 0x04, 0x5c, 0x17, 0x2d, 0x6b, 0xd3, 0x90, 0xbd, 0x85, 0x5f, 0x08, 0x6e, 0x3e, 0x9d, 0x52, 0x5b, 0x46, 0xbf, 0xe2, 0x45, 0x11, 0x43, 0x15, 0x32}; @@ -46,7 +47,7 @@ TEST_CASE("Sha3 abc hash", "[misc]") { } TEST_CASE("Sha3-256 reference", "[misc]") { - static const util::Digest<256>::Type SHA3_256_0xa3_200_times = { + static const util::Digest<256> SHA3_256_0xa3_200_times = { 0x79, 0xf3, 0x8a, 0xde, 0xc5, 0xc2, 0x03, 0x07, 0xa9, 0x8e, 0xf7, 0x6e, 0x83, 0x24, 0xaf, 0xbf, 0xd4, 0x6c, 0xfd, 0x81, 0xb2, 0x2e, 0x39, 0x73, 0xc6, 0x5f, 0xa1, 0xbd, 0x9d, 0xe3, 0x17, 0x87}; @@ -69,7 +70,7 @@ TEST_CASE("Sha3-256 reference", "[misc]") { } TEST_CASE("Sha3-384 reference", "[misc]") { - static const util::Digest<384>::Type SHA3_384_0xa3_200_times = { + static const util::Digest<384> SHA3_384_0xa3_200_times = { 0x18, 0x81, 0xde, 0x2c, 0xa7, 0xe4, 0x1e, 0xf9, 0x5d, 0xc4, 0x73, 0x2b, 0x8f, 0x5f, 0x00, 0x2b, 0x18, 0x9c, 0xc1, 0xe4, 0x2b, 0x74, 0x16, 0x8e, 0xd1, 0x73, 0x26, 0x49, 0xce, 0x1d, 0xbc, 0xdd, 0x76, 0x19, 0x7a, 0x31, @@ -139,3 +140,12 @@ TEST_CASE("Sha3 hash array", "[misc]") { auto act = util::Hash<256>{}.Update(abc_arr).Finalize(); REQUIRE(ref == act); } + +TEST_CASE("Sha3 field elements", "[misc]") { + math::Fp<61> x(123); + math::Fp<61> y(555); + + auto hx = util::Hash<256>{}.Update(x).Finalize(); + auto hy = util::Hash<256>{}.Update(y).Finalize(); + REQUIRE(hx != hy); +}