diff --git a/CMakeLists.txt b/CMakeLists.txt index 3a26ca2..1e4e9c8 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -16,7 +16,7 @@ cmake_minimum_required( VERSION 3.14 ) -project( scl VERSION 6.1.0 DESCRIPTION "Secure Computation Library" ) +project( scl VERSION 6.2.0 DESCRIPTION "Secure Computation Library" ) if(NOT CMAKE_BUILD_TYPE) set(CMAKE_BUILD_TYPE Release) @@ -39,6 +39,7 @@ set(SCL_SOURCE_FILES src/scl/util/prg.cc src/scl/util/sha3.cc src/scl/util/sha256.cc + src/scl/util/cmdline.cc src/scl/math/mersenne61.cc src/scl/math/mersenne127.cc @@ -100,6 +101,7 @@ if(CMAKE_BUILD_TYPE MATCHES "Debug") test/scl/util/test_sha3.cc test/scl/util/test_sha256.cc test/scl/util/test_ecdsa.cc + test/scl/util/test_cmdline.cc test/scl/gf7.cc test/scl/math/test_mersenne61.cc @@ -152,10 +154,8 @@ if(CMAKE_BUILD_TYPE MATCHES "Debug") include(Catch) include(${CMAKE_SOURCE_DIR}/cmake/CodeCoverage.cmake) - # Lower the max size of Vec/Mat reads to speed up tests - add_compile_definitions(MAX_VEC_READ_SIZE=1024) - add_compile_definitions(MAX_MAT_READ_SIZE=1024) add_compile_definitions(SCL_TEST_DATA_DIR="${CMAKE_SOURCE_DIR}/test/data/") + add_compile_definitions(SCL_UTIL_NO_EXIT_ON_ERROR) add_executable(scl_test ${SCL_SOURCE_FILES} ${SCL_TEST_SOURCE_FILES}) target_link_libraries(scl_test Catch2::Catch2 pthread) diff --git a/RELEASE.txt b/RELEASE.txt index ba20f1a..cb2d2b1 100644 --- a/RELEASE.txt +++ b/RELEASE.txt @@ -1,3 +1,9 @@ +6.2.0: More functionality for Number +- Add modulo operator to Number. +- Add some mathematical functions that operate on numbers. +- Make Number serializable; add Serializer specialization. +- Add a simple command-line argument parser. + 6.1.0: Extend serialization functionality - Make Write methods return the number of bytes written. - Make it possible to serialize vectors with arbitrary content. diff --git a/include/scl/math/number.h b/include/scl/math/number.h index 677a9f8..8f1d752 100644 --- a/include/scl/math/number.h +++ b/include/scl/math/number.h @@ -28,6 +28,43 @@ namespace scl::math { +class Number; + +/** + * @brief Compute the least common multiple of two numbers. + * @param a the first number. + * @param b the second number. + * @return \f$lcm(a, b)\f$. + */ +Number LCM(const Number& a, const Number& b); + +/** + * @brief Compute the greatest common divisor of two numbers. + * @param a the first number. + * @param b the second number. + * @return \f$gcd(a, b)\f$. + */ +Number GCD(const Number& a, const Number& b); + +/** + * @brief Compute the modular inverse of a number. + * @param val the value to invert. + * @param mod the modulus. + * @return \f$val^{-1} \mod mod \f$. + * @throws std::logic_error if \p val is not invertible. + * @throws std::invalid_argument if \p mod is 0. + */ +Number ModInverse(const Number& val, const Number& mod); + +/** + * @brief Compute a modular exponentiation. + * @param base the base. + * @param exp the exponent. + * @param mod the modulus. + * @return \f$base^{exp} \mod mod\f$. + */ +Number ModExp(const Number& base, const Number& exp, const Number& mod); + /** * @brief Arbitrary precision integer. */ @@ -35,12 +72,20 @@ class Number final : Print { public: /** * @brief Generate a random Number. - * @param bits the number of bits in the resulting number - * @param prg a prg for generating the random number - * @return a random Number + * @param bits the number of bits in the resulting number. + * @param prg a prg for generating the random number. + * @return a random Number. */ static Number Random(std::size_t bits, util::PRG& prg); + /** + * @brief Generate a random prime. + * @param bits the number of bits in the resulting prime. + * @param prg a prg for generating the random prime. + * @return a random prime. + */ + static Number RandomPrime(std::size_t bits, util::PRG& prg); + /** * @brief Read a Number from a string * @param str the string @@ -48,6 +93,13 @@ class Number final : Print { */ static Number FromString(const std::string& str); + /** + * @brief Read a number from a buffer. + * @param buf the buffer. + * @return a Number. + */ + static Number Read(const unsigned char* buf); + /** * @brief Construct a Number from an int. * @param value the int @@ -85,7 +137,7 @@ class Number final : Print { Number copy(number); swap(*this, copy); return *this; - }; + } /** * @brief Move assignment from a Number. @@ -95,7 +147,7 @@ class Number final : Print { Number& operator=(Number&& number) noexcept { swap(*this, number); return *this; - }; + } /** * @brief In-place addition of two numbers. @@ -105,7 +157,7 @@ class Number final : Print { Number& operator+=(const Number& number) { *this = *this + number; return *this; - }; + } /** * @brief Add two numbers. @@ -122,7 +174,7 @@ class Number final : Print { Number& operator-=(const Number& number) { *this = *this - number; return *this; - }; + } /** * @brief Subtract two Numbers. @@ -145,7 +197,7 @@ class Number final : Print { Number& operator*=(const Number& number) { *this = *this * number; return *this; - }; + } /** * @brief Multiply two Numbers. @@ -162,7 +214,7 @@ class Number final : Print { Number& operator/=(const Number& number) { *this = *this / number; return *this; - }; + } /** * @brief Divide two Numbers. @@ -171,6 +223,23 @@ class Number final : Print { */ Number operator/(const Number& number) const; + /** + * @brief In-place modulo operator. + * @param mod the modulus. + * @return this. + */ + Number& operator%=(const Number& mod) { + *this = *this % mod; + return *this; + } + + /** + * @brief Modulo operation. + * @param mod the modulus. + * @return \p this modulo \p mod. + */ + Number operator%(const Number& mod) const; + /** * @brief In-place left shift. * @param shift the amount to left shift @@ -179,7 +248,7 @@ class Number final : Print { Number& operator<<=(int shift) { *this = *this << shift; return *this; - }; + } /** * @brief Perform a left shift of a Number. @@ -196,7 +265,7 @@ class Number final : Print { Number& operator>>=(int shift) { *this = *this >> shift; return *this; - }; + } /** * @brief Perform a right shift of a Number. @@ -213,7 +282,7 @@ class Number final : Print { Number& operator^=(const Number& number) { *this = *this ^ number; return *this; - }; + } /** * @brief Exclusive or of two numbers. @@ -230,7 +299,7 @@ class Number final : Print { Number& operator|=(const Number& number) { *this = *this | number; return *this; - }; + } /** * @brief operator | @@ -247,7 +316,7 @@ class Number final : Print { Number& operator&=(const Number& number) { *this = *this & number; return *this; - }; + } /** * @brief operator & @@ -280,42 +349,47 @@ class Number final : Print { */ friend bool operator==(const Number& lhs, const Number& rhs) { return lhs.Compare(rhs) == 0; - }; + } /** * @brief In-equality of two numbers. */ friend bool operator!=(const Number& lhs, const Number& rhs) { return lhs.Compare(rhs) != 0; - }; + } /** * @brief Strictly less-than of two numbers. */ friend bool operator<(const Number& lhs, const Number& rhs) { return lhs.Compare(rhs) < 0; - }; + } /** * @brief Less-than-or-equal of two numbers. */ friend bool operator<=(const Number& lhs, const Number& rhs) { return lhs.Compare(rhs) <= 0; - }; + } /** * @brief Strictly greater-than of two numbers. */ friend bool operator>(const Number& lhs, const Number& rhs) { return lhs.Compare(rhs) > 0; - }; + } /** * @brief Greater-than-or-equal of two numbers. */ friend bool operator>=(const Number& lhs, const Number& rhs) { return lhs.Compare(rhs) >= 0; - }; + } + + /** + * @brief Get the size of this number in bytes. + */ + std::size_t ByteSize() const; /** * @brief Get the size of this Number in bits. @@ -340,7 +414,7 @@ class Number final : Print { */ bool Odd() const { return TestBit(0); - }; + } /** * @brief Test if this Number is even. @@ -348,7 +422,13 @@ class Number final : Print { */ bool Even() const { return !Odd(); - }; + } + + /** + * @brief Write this number to a buffer. + * @param buf the buffer. + */ + void Write(unsigned char* buf) const; /** * @brief Return a string representation of this Number. @@ -362,10 +442,17 @@ class Number final : Print { friend void swap(Number& first, Number& second) { using std::swap; swap(first.m_value, second.m_value); - }; + } private: mpz_t m_value; + + friend Number LCM(const Number& a, const Number& b); + friend Number GCD(const Number& a, const Number& b); + friend Number ModInverse(const Number& val, const Number& mod); + friend Number ModExp(const Number& base, + const Number& exp, + const Number& mod); }; } // namespace scl::math diff --git a/include/scl/serialization/math_serializers.h b/include/scl/serialization/math_serializers.h index 830fae9..444b0cc 100644 --- a/include/scl/serialization/math_serializers.h +++ b/include/scl/serialization/math_serializers.h @@ -20,6 +20,7 @@ #include "scl/math/ff.h" #include "scl/math/mat.h" +#include "scl/math/number.h" #include "scl/math/vec.h" #include "scl/serialization/serializer.h" @@ -163,6 +164,47 @@ struct Serializer> { } }; +/** + * @brief Serializer specialization for math::Number. + */ +template <> +struct Serializer { + /** + * @brief Get the serialized size of a math::Number. + * @param number the number. + * @return the serialized size of a math::Number. + * + * A math::Number is writte as size_and_sign | number where + * size_and_sign is a 4 byte value containing the byte size of + * the number and its sign. + */ + static std::size_t SizeOf(const math::Number& number) { + return number.ByteSize() + sizeof(std::uint32_t); + } + + /** + * @brief Write a number to a buffer. + * @param number the number. + * @param buf the buffer. + * @return the number of bytes written. + */ + static std::size_t Write(const math::Number& number, unsigned char* buf) { + number.Write(buf); + return SizeOf(number); + } + + /** + * @brief Read a math::Number from a buffer. + * @param number the number. + * @param buf the buffer. + * @return the number of bytes read. + */ + static std::size_t Read(math::Number& number, const unsigned char* buf) { + number = math::Number::Read(buf); + return SizeOf(number); + } +}; + } // namespace scl::seri #endif // SCL_SERIALIZATION_MATH_SERIALIZERS_H diff --git a/include/scl/util/cmdline.h b/include/scl/util/cmdline.h new file mode 100644 index 0000000..b3183d9 --- /dev/null +++ b/include/scl/util/cmdline.h @@ -0,0 +1,320 @@ +/* SCL --- Secure Computation Library + * Copyright (C) 2023 Anders Dalskov + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +#ifndef SCL_UTIL_CMDLINE_H +#define SCL_UTIL_CMDLINE_H + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace scl::util { + +/** + * @brief Simple command line argument parser. + * + * ProgramOptions allows defining and parsing options for a program in a limited + * manner using a builder pattern. For example: + * + * @code + * auto p = ProgramOptions::Parser("some description") + * .Add(ProgramArg::Required("foo", "int", "foo description")) + * .Add(ProgramArg::Optional("bar", "bool", "123")) + * .Add(ProgramFlag("flag")) + * .Parse(argc, argv); + * @endcode + * + * The above snippet will parse the argv argument vector passed to + * a program looking for arguments -foo value and + * flag. The bar is optional and if not explicitly + * supplied, gets the default value "123". + */ +class ProgramOptions { + public: + class Parser; + + /** + * @brief Check if some argument has been provided. + * @param name the name of the argument. + * @return true if the argument was set, false otherwise. + */ + bool Has(std::string_view name) const { + return mArgs.find(name) != mArgs.end(); + } + + /** + * @brief Check if a flag has been set. + * @param name the name of the flag. + * @return true if the flag was set, false otherwise. + */ + bool FlagSet(std::string_view name) const { + return mFlags.find(name) != mFlags.end(); + } + + /** + * @brief Get the raw value of an argument. + * @param name the name of the argument. + * @return the value of the argument, as is. + */ + std::string_view Get(std::string_view name) const { + return mArgs.at(name); + } + + /** + * @brief Get the value of an argument with conversion. + * @tparam T the type to convert the argument to. + * @param name the name of the argument. + * @return the value of the argument after conversion. + * + * Specializations exist for this function for bool, + * int and std::size_t. It is possible to provide + * custom specializations that can be used to turn a string into any kind of + * object. + */ + template + T Get(std::string_view name) const; + + private: + ProgramOptions( + const std::unordered_map& args, + const std::unordered_map& flags) + : mArgs(args), mFlags(flags){}; + + std::unordered_map mArgs; + std::unordered_map mFlags; +}; + +/** + * @brief Specialization of CmdArgs::Get for bool. + */ +template <> +inline bool ProgramOptions::Get(std::string_view name) const { + const auto v = mArgs.at(name); + return v == "1" || v == "true"; +} + +/** + * @brief Specialization for CmdArgs::Get for int. + */ +template <> +inline int ProgramOptions::Get(std::string_view name) const { + return std::stoi(mArgs.at(name).data()); +} + +/** + * @brief Specialization of CmdArgs::Get for std::size_t. + */ +template <> +inline std::size_t ProgramOptions::Get( + std::string_view name) const { + return std::stoul(mArgs.at(name).data()); +} + +/** + * @brief An command-line argument definition. + */ +struct ProgramArg { + /** + * @brief Create a required command-line argument. + * @param name the name. + * @param type_hint a string describing the expected type. E.g., "int". + * @param description a short description. + */ + static ProgramArg Required(std::string_view name, + std::string_view type_hint, + std::string_view description = "") { + return ProgramArg{true, name, type_hint, description, {}}; + } + + /** + * @brief Create an optional command-line argument. + * @param name the name. + * @param type_hint a string describing the expected type. E.g., "int". + * @param default_value an optional default value. + * @param description a short description. + */ + static ProgramArg Optional(std::string_view name, + std::string_view type_hint, + std::optional default_value, + std::string_view description = "") { + return ProgramArg{false, name, type_hint, description, default_value}; + } + + /** + * @brief Whether this argument is required. + */ + bool required; + + /** + * @brief The name of this argument. + */ + std::string_view name; + + /** + * @brief A type hint. Only used as part of the description. + */ + std::string_view type_hint; + + /** + * @brief A short description of this argument. + */ + std::string_view description; + + /** + * @brief A default value. Ignored if \p required is true. + */ + std::optional default_value; +}; + +/** + * @brief A command-line argument flag definition. + */ +struct ProgramFlag { + /** + * @brief Create a flag argument. + * @param name the name of the flag. + * @param description a description. + */ + ProgramFlag(std::string_view name, std::string_view description = "") + : name(name), description(description) {} + + /** + * @brief The name. + */ + std::string_view name; + + /** + * @brief A short descruption. + */ + std::string_view description; +}; + +/** + * @brief Argument parser. + * + * The parser accepts argument defintions (through the Add functions) and + * parses the arguments provided to the main function into a CmdArgs object. + */ +class ProgramOptions::Parser { + public: + /** + * @brief Create a command-line argument parser. + * @param description a short description of the program. + */ + Parser(std::string_view description = "") : mDescription(description) {} + + /** + * @brief Define an argument. + * @param def an argument definition. + */ + Parser& Add(const ProgramArg& def) { + if (Exists(def)) { + PrintHelp("duplicate argument definition"); + } + mArgs.emplace_back(def); + return *this; + } + + /** + * @brief Define a flag argument. + * @param flag a flag definition. + */ + Parser& Add(const ProgramFlag& flag) { + if (Exists(flag)) { + PrintHelp("duplicate argument definition"); + } + mFlags.emplace_back(flag); + return *this; + } + + /** + * @brief Parse arguments. + * @param argc the number of arguments. + * @param argv the arguments. + * + * The \p argc and \p argv are assumed to be the inputs to a programs main + * function. + */ + ProgramOptions Parse(int argc, char* argv[]); + + /** + * @brief Print a help string to stdout. + */ + void Help() const { + ArgListLong(std::cout); + } + + private: + template + bool Exists(const T& arg_or_flag) const; + void ArgListShort(std::ostream& stream, std::string_view program_name) const; + void ArgListLong(std::ostream& stream) const; + + bool IsArg(std::string_view name) const; + bool IsFlag(std::string_view name) const; + + template + void ForEachOptional(const std::list& list, P pred) const { + std::for_each(list.begin(), list.end(), [&](const auto e) { + if (!e.required) { + pred(e); + } + }); + } + + template + void ForEachRequired(const std::list& list, P pred) const { + std::for_each(list.begin(), list.end(), [&](const auto e) { + if (e.required) { + pred(e); + } + }); + } + + void PrintHelp(std::string_view error_msg = ""); + + std::string_view mDescription; + std::string_view mProgramName; + + std::list mArgs; + std::list mFlags; +}; + +template +bool ProgramOptions::Parser::Exists(const T& arg_or_flag) const { + const auto exists_a = std::any_of(mArgs.begin(), mArgs.end(), [&](auto a) { + return a.name == arg_or_flag.name; + }); + if (exists_a) { + return true; + } + + const auto exists_f = std::any_of(mFlags.begin(), mFlags.end(), [&](auto a) { + return a.name == arg_or_flag.name; + }); + return exists_f; +} + +} // namespace scl::util + +#endif // SCL_UTIL_CMDLINE_H diff --git a/src/scl/math/number.cc b/src/scl/math/number.cc index 3959665..485f668 100644 --- a/src/scl/math/number.cc +++ b/src/scl/math/number.cc @@ -17,11 +17,14 @@ #include "scl/math/number.h" +#include #include #include #include #include +#include + scl::math::Number::Number() { mpz_init(m_value); } @@ -54,12 +57,35 @@ scl::math::Number scl::math::Number::Random(std::size_t bits, util::PRG& prg) { return r; } +scl::math::Number scl::math::Number::RandomPrime(std::size_t bits, + util::PRG& prg) { + auto r = Random(bits, prg); + Number prime; + mpz_nextprime(prime.m_value, r.m_value); + return prime; +} + scl::math::Number scl::math::Number::FromString(const std::string& str) { scl::math::Number num; mpz_set_str(num.m_value, str.c_str(), 16); return num; } // LCOV_EXCL_LINE +scl::math::Number scl::math::Number::Read(const unsigned char* buf) { + std::uint32_t size_and_sign; + std::memcpy(&size_and_sign, buf, sizeof(std::uint32_t)); + + bool negative = (size_and_sign >> 31) == 1; + auto size = size_and_sign & ((1 << 30) - 1); + + Number r; + mpz_import(r.m_value, size, 1, 1, 0, 0, buf + sizeof(std::uint32_t)); + if (negative) { + mpz_neg(r.m_value, r.m_value); + } + return r; +} // LCOV_EXCL_LINE + scl::math::Number::Number(int value) : Number() { mpz_set_si(m_value, value); } @@ -97,6 +123,12 @@ scl::math::Number scl::math::Number::operator/(const Number& number) const { return frac; } // LCOV_EXCL_LINE +scl::math::Number scl::math::Number::operator%(const Number& mod) const { + scl::math::Number res; + mpz_mod(res.m_value, m_value, mod.m_value); + return res; +} // LCOV_EXCL_LINE + scl::math::Number scl::math::Number::operator<<(int shift) const { scl::math::Number shifted; if (shift < 0) { @@ -145,6 +177,10 @@ int scl::math::Number::Compare(const Number& number) const { return mpz_cmp(m_value, number.m_value); } +std::size_t scl::math::Number::ByteSize() const { + return (BitSize() - 1) / 8 + 1; +} + std::size_t scl::math::Number::BitSize() const { return mpz_sizeinbase(m_value, 2); } @@ -161,3 +197,48 @@ std::string scl::math::Number::ToString() const { free(cstr); return ss.str(); } + +void scl::math::Number::Write(unsigned char* buf) const { + std::uint32_t size_and_sign = ByteSize(); + + if (mpz_sgn(m_value) < 0) { + size_and_sign |= (1 << 31); + } + + std::memcpy(buf, &size_and_sign, sizeof(std::uint32_t)); + mpz_export(buf + sizeof(std::uint32_t), NULL, 1, 1, 0, 0, m_value); +} + +scl::math::Number scl::math::LCM(const Number& a, const Number& b) { + Number lcm; + mpz_lcm(lcm.m_value, a.m_value, b.m_value); + return lcm; +} // LCOV_EXCL_LINE + +scl::math::Number scl::math::GCD(const Number& a, const Number& b) { + Number gcd; + mpz_gcd(gcd.m_value, a.m_value, b.m_value); + return gcd; +} // LCOV_EXCL_LINE + +scl::math::Number scl::math::ModInverse(const Number& val, const Number& mod) { + if (mpz_sgn(mod.m_value) == 0) { + throw std::invalid_argument("modulus cannot be 0"); + } + + Number inv; + auto e = mpz_invert(inv.m_value, val.m_value, mod.m_value); + if (e == 0) { + throw std::logic_error("number not invertible"); + } + + return inv; +} // LCOV_EXCL_LINE + +scl::math::Number scl::math::ModExp(const Number& base, + const Number& exp, + const Number& mod) { + Number r; + mpz_powm(r.m_value, base.m_value, exp.m_value, mod.m_value); + return r; +} // LCOV_EXCL_LINE diff --git a/src/scl/util/cmdline.cc b/src/scl/util/cmdline.cc new file mode 100644 index 0000000..123eced --- /dev/null +++ b/src/scl/util/cmdline.cc @@ -0,0 +1,199 @@ +/* SCL --- Secure Computation Library + * Copyright (C) 2023 Anders Dalskov + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +#include "scl/util/cmdline.h" + +#include + +namespace {} // namespace + +bool scl::util::ProgramOptions::Parser::IsArg(std::string_view name) const { + return std::any_of(mArgs.begin(), mArgs.end(), [&](auto a) { + return a.name == name; + }); +} + +bool scl::util::ProgramOptions::Parser::IsFlag(std::string_view name) const { + return std::any_of(mFlags.begin(), mFlags.end(), [&](auto f) { + return f.name == name; + }); +} + +namespace { + +bool Name(std::string_view opt_name, std::string_view& name) { + if (opt_name[0] != '-') { + return false; + } + name = opt_name.substr(1, opt_name.size()); + return true; +} + +} // namespace + +scl::util::ProgramOptions scl::util::ProgramOptions::Parser::Parse( + int argc, + char** argv) { + mProgramName = argv[0]; + std::vector cmd_args(argv + 1, argv + argc); + + const auto help_needed = std::any_of(cmd_args.begin(), + cmd_args.end(), + [](auto e) { return e == "-help"; }); + if (help_needed) { + PrintHelp(); + } + + std::unordered_map args; + std::for_each(mArgs.begin(), mArgs.end(), [&args](const auto arg) { + if (arg.default_value.has_value()) { + args[arg.name] = arg.default_value.value(); + } + }); + + std::unordered_map flags; + std::size_t i = 0; + while (i < cmd_args.size()) { + std::string_view name; + if (!Name(cmd_args[i++], name)) { + PrintHelp("argument must begin with '-'"); + } + + if (IsArg(name)) { + if (i == cmd_args.size()) { + PrintHelp("invalid argument"); + } + args[name] = cmd_args[i++]; + } else if (IsFlag(name)) { + flags[name] = true; + } else { + PrintHelp("encountered unknown argument"); + } + } + + // check if we got everything + ForEachRequired(mArgs, [&](const auto arg) { + if (args.find(arg.name) == args.end()) { + PrintHelp("missing required argument"); + } + }); + + return ProgramOptions(args, flags); +} + +void scl::util::ProgramOptions::Parser::ArgListShort( + std::ostream& stream, + std::string_view program_name) const { + stream << "Usage: " << program_name << " "; + ForEachRequired(mArgs, [&stream](const auto arg) { + stream << "-" << arg.name << " " << arg.type_hint << " "; + }); + + stream << "[options ...]" << std::endl; +} + +std::string GetPadding(std::size_t lead) { + const static std::size_t padding = 20; + const static std::size_t min_padding = 5; + const auto psz = lead >= padding + min_padding ? min_padding : padding - lead; + return std::string(psz, ' '); +} + +void WriteArg(std::ostream& stream, const scl::util::ProgramArg& arg) { + stream << " -" << arg.name << " '" << arg.type_hint << "'"; + if (!arg.description.empty()) { + const auto pad_str = GetPadding(arg.name.size() + arg.type_hint.size() + 5); + stream << pad_str << arg.description << ". "; + } + if (arg.default_value.has_value()) { + stream << " [default=" << arg.default_value.value() << "]"; + } + stream << std::endl; +} + +void WriteFlag(std::ostream& stream, const scl::util::ProgramFlag& flag) { + stream << " -" << flag.name; + if (!flag.description.empty()) { + const auto pad_str = GetPadding(flag.name.size() + 2); + stream << pad_str << flag.description << ". "; + } + stream << std::endl; +} + +template +bool HasRequired(It begin, It end) { + return std::any_of(begin, end, [](const auto a) { return a.required; }); +} + +template +bool HasOptional(It begin, It end) { + return std::any_of(begin, end, [](const auto a) { return !a.required; }); +} + +void scl::util::ProgramOptions::Parser::ArgListLong( + std::ostream& stream) const { + if (!mDescription.empty()) { + stream << std::endl << mDescription << std::endl; + } + stream << std::endl; + + const auto has_req_arg = HasRequired(mArgs.begin(), mArgs.end()); + + if (has_req_arg) { + stream << "Required arguments" << std::endl; + ForEachRequired(mArgs, [&stream](const auto a) { WriteArg(stream, a); }); + stream << std::endl; + } + + if (HasOptional(mArgs.begin(), mArgs.end())) { + stream << "Optional Arguments" << std::endl; + + ForEachOptional(mArgs, [&stream](const auto a) { WriteArg(stream, a); }); + stream << std::endl; + } + + if (!mFlags.empty()) { + stream << "Flags" << std::endl; + std::for_each(mFlags.begin(), mFlags.end(), [&stream](const auto a) { + WriteFlag(stream, a); + }); + stream << std::endl; + } +} + +void scl::util::ProgramOptions::Parser::PrintHelp(std::string_view error_msg) { + bool error = !error_msg.empty(); + + if (error) { + std::cerr << "ERROR: " << error_msg << std::endl; + } + + if (!mProgramName.empty()) { + ArgListShort(std::cout, mProgramName); + } + ArgListLong(std::cout); + +#ifdef SCL_UTIL_NO_EXIT_ON_ERROR + + throw std::runtime_error(error ? "bad" : "good"); + +#else + + std::exit(error ? 1 : 0); + +#endif +} diff --git a/src/scl/util/prg.cc b/src/scl/util/prg.cc index 67e5c32..9ce3778 100644 --- a/src/scl/util/prg.cc +++ b/src/scl/util/prg.cc @@ -19,6 +19,7 @@ #include #include +#include #include diff --git a/test/scl/math/test_number.cc b/test/scl/math/test_number.cc index 3f7cbdf..8468d3c 100644 --- a/test/scl/math/test_number.cc +++ b/test/scl/math/test_number.cc @@ -17,6 +17,7 @@ #include #include +#include #include "scl/math/number.h" #include "scl/util/prg.h" @@ -49,6 +50,9 @@ TEST_CASE("Number create", "[math]") { REQUIRE(r1.ToString() == "Number{10584d2a1c30fa50d}"); REQUIRE(r1.BitSize() == 65); + Number p = Number::RandomPrime(10, prg); + REQUIRE(p.ToString() == "Number{133}"); // 307 + std::stringstream ss; ss << r1; REQUIRE(ss.str() == r1.ToString()); @@ -201,6 +205,14 @@ TEST_CASE("Number division", "[math]") { Catch::Matchers::Message("division by 0")); } +TEST_CASE("Number modulus", "[math]") { + Number a(42); + Number b(10); + + REQUIRE(a % b == Number(2)); + REQUIRE(b % a == Number(10)); +} + TEST_CASE("Number bit-shift", "[math]") { Number a(44334); REQUIRE(a << 5 == Number(1418688)); @@ -298,3 +310,59 @@ TEST_CASE("Number test bit", "[math]") { REQUIRE(a.TestBit(4)); REQUIRE(a.TestBit(5)); } + +TEST_CASE("Number mod inverse invalid", "[math]") { + Number a(10); + REQUIRE_THROWS_MATCHES(math::ModInverse(a, Number(0)), + std::invalid_argument, + Catch::Matchers::Message("modulus cannot be 0")); + + REQUIRE_THROWS_MATCHES(math::ModInverse(a, Number(2)), + std::logic_error, + Catch::Matchers::Message("number not invertible")); +} + +TEST_CASE("Number read/write", "[math]") { + Number a(1234); + + REQUIRE(a.BitSize() == 11); + REQUIRE(a.ByteSize() == 2); + + auto buf = + std::make_unique(a.ByteSize() + sizeof(std::uint32_t)); + + a.Write(buf.get()); + REQUIRE(a == Number::Read(buf.get())); + + auto prg = util::PRG::Create("rw"); + REPEAT { + const auto x = Number::Random(100, prg); + auto bufx = + std::make_unique(x.ByteSize() + sizeof(std::uint32_t)); + x.Write(bufx.get()); + REQUIRE(x == Number::Read(bufx.get())); + } +} + +TEST_CASE("Number RSA example", "[math]") { + auto prg = util::PRG::Create("rsa"); + const auto p = Number::RandomPrime(512, prg); + const auto q = Number::RandomPrime(512, prg); + REQUIRE(p != q); + const auto n = p * q; + const auto lm = math::LCM(p - Number(1), q - Number(1)); + + const auto e = Number(0x10001); + REQUIRE(math::GCD(e, lm) == Number(1)); + + const auto d = math::ModInverse(e, lm); + REQUIRE((d * e) % lm == Number(1)); + + Number msg(1234); + + const auto ctxt = math::ModExp(msg, e, n); + REQUIRE(ctxt != msg); + + const auto ptxt = math::ModExp(ctxt, d, n); + REQUIRE(ptxt == msg); +} diff --git a/test/scl/serialization/test_serializer.cc b/test/scl/serialization/test_serializer.cc index 45428cc..4e82965 100644 --- a/test/scl/serialization/test_serializer.cc +++ b/test/scl/serialization/test_serializer.cc @@ -18,6 +18,8 @@ #include #include "scl/math/fp.h" +#include "scl/math/number.h" +#include "scl/serialization/serializer.h" #include "scl/serialization/serializers.h" using namespace scl; @@ -117,3 +119,32 @@ TEST_CASE("Serialization Vec", "[misc]") { REQUIRE(v == w); } + +TEST_CASE("Serialization number", "[misc]") { + using Sn = seri::Serializer; + + math::Number a(1234); + auto buf = std::make_unique(Sn::SizeOf(a)); + + Sn::Write(a, buf.get()); + math::Number b; + Sn::Read(b, buf.get()); + + REQUIRE(a == b); +} + +TEST_CASE("Serialization number vector", "[misc]") { + using Sn = seri::Serializer>; + + std::vector nums = {math::Number(22222123), + math::Number(123), + math::Number(-10)}; + + auto buf = std::make_unique(Sn::SizeOf(nums)); + Sn::Write(nums, buf.get()); + + std::vector r; + Sn::Read(r, buf.get()); + + REQUIRE(nums == r); +} diff --git a/test/scl/simulation/test_simulator.cc b/test/scl/simulation/test_simulator.cc index 896ca64..8583de6 100644 --- a/test/scl/simulation/test_simulator.cc +++ b/test/scl/simulation/test_simulator.cc @@ -248,8 +248,7 @@ TEST_CASE("Simulation odd/even iterations", "[sim]") { sim::DataMeasurement m_even; sim::DataMeasurement m_odd; - // Cannot use SECTION here as because it m_even somehow gets overwritten with - // garbage... + // Cannot use SECTION here as m_even somehow gets overwritten with garbage... { const auto creator = []() { @@ -327,8 +326,7 @@ TEST_CASE("Simulation receive out-of-order", "[sim]") { /** * @brief Two party protocol that uses HasData. * - * This protocol captures both failure cases for simulating a HasData call. - Both + * This protocol captures both failure cases for simulating a HasData call. Both * failure cases arise */ struct HasDataProtocol { diff --git a/test/scl/util/test_cmdline.cc b/test/scl/util/test_cmdline.cc new file mode 100644 index 0000000..f39dc5e --- /dev/null +++ b/test/scl/util/test_cmdline.cc @@ -0,0 +1,187 @@ +/* SCL --- Secure Computation Library + * Copyright (C) 2023 Anders Dalskov + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +#include +#include + +#include "scl/util/cmdline.h" + +using namespace scl; + +#define CAPTURE_START \ + std::stringstream scl_cerr_buf; \ + std::stringstream scl_cout_buf; \ + std::streambuf* scl_cerr = std::cerr.rdbuf(scl_cerr_buf.rdbuf()); \ + std::streambuf* scl_cout = std::cout.rdbuf(scl_cout_buf.rdbuf()) + +#define CAPTURE_END(output_cout, output_cerr) \ + auto(output_cout) = scl_cout_buf.str(); \ + auto(output_cerr) = scl_cerr_buf.str(); \ + std::cout.rdbuf(scl_cout); \ + std::cerr.rdbuf(scl_cerr) + +#define WITH_EXIT_0(expr) \ + REQUIRE_THROWS_MATCHES((expr), \ + std::runtime_error, \ + Catch::Matchers::Message("good")) + +#define WITH_EXIT_1(expr) \ + REQUIRE_THROWS_MATCHES((expr), \ + std::runtime_error, \ + Catch::Matchers::Message("bad")) + +TEST_CASE("Cmdline print help", "[util]") { + const char* argv[] = {"program", "-help"}; + + auto p = util::ProgramOptions::Parser("Program description.") + .Add(util::ProgramArg::Optional("x", "y", "default")) + .Add(util::ProgramArg::Required("a", "b", "arg description")) + .Add(util::ProgramFlag("w", "flag description")); + + CAPTURE_START; + + WITH_EXIT_0(p.Parse(2, (char**)argv)); + + CAPTURE_END(outc, oute); + + REQUIRE(oute.empty()); + + REQUIRE_THAT(outc, Catch::Matchers::StartsWith("Usage: program")); + REQUIRE_THAT(outc, Catch::Matchers::Contains("Program description.")); + REQUIRE_THAT(outc, Catch::Matchers::Contains("-x 'y'")); + REQUIRE_THAT(outc, Catch::Matchers::Contains("-a 'b'")); + REQUIRE_THAT(outc, Catch::Matchers::Contains("-w")); + REQUIRE_THAT(outc, Catch::Matchers::Contains("arg description.")); + REQUIRE_THAT(outc, Catch::Matchers::Contains("flag description.")); + REQUIRE_THAT(outc, Catch::Matchers::Contains("[default=default]")); +} + +TEST_CASE("Cmdline parse with error", "[util]") { + const char* argv[] = {"program", "-x"}; + + auto p = util::ProgramOptions::Parser{}; + + CAPTURE_START; + WITH_EXIT_1(p.Parse(2, (char**)argv)); + CAPTURE_END(outc, oute); + + REQUIRE_THAT(outc, Catch::Matchers::StartsWith("Usage: program")); + REQUIRE(oute == "ERROR: encountered unknown argument\n"); +} + +TEST_CASE("Cmdline parse missing required", "[util]") { + const char* argv[] = {"program"}; + auto p = + util::ProgramOptions::Parser{}.Add(util::ProgramArg::Required("x", "y")); + + CAPTURE_START; + WITH_EXIT_1(p.Parse(1, (char**)argv)); + CAPTURE_END(outc, oute); + + REQUIRE(oute == "ERROR: missing required argument\n"); +} + +TEST_CASE("Cmdline parse invalid argument", "[util]") { + const char* argv[] = {"program", "-x"}; + auto p = + util::ProgramOptions::Parser{}.Add(util::ProgramArg::Required("x", "y")); + + CAPTURE_START; + WITH_EXIT_1(p.Parse(2, (char**)argv)); + CAPTURE_END(outc, oute); + + REQUIRE(oute == "ERROR: invalid argument\n"); +} + +TEST_CASE("Cmdline parse invalid argument name", "[util]") { + const char* argv[] = {"program", "x"}; + auto p = + util::ProgramOptions::Parser{}.Add(util::ProgramArg::Required("x", "y")); + + CAPTURE_START; + WITH_EXIT_1(p.Parse(2, (char**)argv)); + CAPTURE_END(outc, oute); + + REQUIRE(oute == "ERROR: argument must begin with '-'\n"); +} + +TEST_CASE("Cmdline duplicate arg definition", "[util]") { + auto p = util::ProgramOptions::Parser{}.Add( + util::ProgramArg::Required("x", "int")); + + CAPTURE_START; + WITH_EXIT_1(p.Add(util::ProgramArg::Required("x", "int"))); + CAPTURE_END(outc, oute); + + REQUIRE(oute == "ERROR: duplicate argument definition\n"); +} + +TEST_CASE("Cmdline duplicate flag definition", "[util]") { + auto p = util::ProgramOptions::Parser{}.Add(util::ProgramFlag("x")); + + CAPTURE_START; + WITH_EXIT_1(p.Add(util::ProgramFlag("x"))); + CAPTURE_END(outc, oute); + + REQUIRE(oute == "ERROR: duplicate argument definition\n"); +} + +TEST_CASE("Cmdline parse duplicate arg", "[misc]") { + const char* argv[] = {"program", "-x", "1", "-x", "2"}; + auto p = util::ProgramOptions::Parser{} + .Add(util::ProgramArg::Required("x", "int")) + .Parse(5, (char**)argv); + REQUIRE(p.Get("x") == "2"); +} + +TEST_CASE("Cmdline arg", "[util]") { + const char* argv[] = {"program", "-x", "100", "-w", "600", "-b", "true"}; + auto p = util::ProgramOptions::Parser{} + .Add(util::ProgramArg::Required("x", "int")) + .Add(util::ProgramArg::Required("w", "ulong")) + .Add(util::ProgramArg::Required("b", "bool")) + .Add(util::ProgramArg::Optional("y", "long", "100")) + .Parse(7, (char**)argv); + + REQUIRE(p.Has("x")); + auto v = p.Get("x"); + REQUIRE(v == "100"); + auto w = p.Get("x"); + REQUIRE(w == 100); + + REQUIRE(p.Has("w")); + auto ww = p.Get("w"); + REQUIRE(ww == 600); + + REQUIRE(p.Has("b")); + REQUIRE(p.Get("b")); + + REQUIRE(p.Has("y")); + REQUIRE(p.Get("y") == 100); +} + +TEST_CASE("Cmdline flag", "[util]") { + const char* argv[] = {"program", "-f"}; + auto p = util::ProgramOptions::Parser{} + .Add(util::ProgramFlag("f")) + .Add(util::ProgramFlag("h")) + .Parse(2, (char**)argv); + + REQUIRE(p.FlagSet("f")); + REQUIRE_FALSE(p.FlagSet("h")); + REQUIRE_FALSE(p.FlagSet("g")); +}