diff --git a/libraries/DNSServer/examples/CaptivePortal/CaptivePortal.ino b/libraries/DNSServer/examples/CaptivePortal/CaptivePortal.ino index 9221af1eaa2..7e292e1adfb 100644 --- a/libraries/DNSServer/examples/CaptivePortal/CaptivePortal.ino +++ b/libraries/DNSServer/examples/CaptivePortal/CaptivePortal.ino @@ -1,52 +1,59 @@ +/* +This example enables catch-all Captive portal for ESP32 Access-Point +It will allow modern devices/OSes to detect that WiFi connection is +limited and offer a user to access a banner web-page. +There is no need to find and open device's IP address/URL, i.e. http://192.168.4.1/ +This works for Android, Ubuntu, FireFox, Windows, maybe others... +*/ + +#include #include #include +#include + -const byte DNS_PORT = 53; -IPAddress apIP(8,8,4,4); // The default android DNS DNSServer dnsServer; -WiFiServer server(80); +WebServer server(80); + +static const char responsePortal[] = R"===( +ESP32 CaptivePortal +

Hello World!

This is a captive portal example page. All unknown http requests will +be redirected here.

+)==="; + +// index page handler +void handleRoot() { + server.send(200, "text/plain", "Hello from esp32!"); +} -String responseHTML = "" - "CaptivePortal" - "

Hello World!

This is a captive portal example. All requests will " - "be redirected here.

"; +// this will redirect unknown http req's to our captive portal page +// based on this redirect various systems could detect that WiFi AP has a captive portal page +void handleNotFound() { + server.sendHeader("Location", "/portal"); + server.send(302, "text/plain", "redirect to captive portal"); +} -void setup() { +void setup() { + Serial.begin(115200); WiFi.mode(WIFI_AP); WiFi.softAP("ESP32-DNSServer"); - WiFi.softAPConfig(apIP, apIP, IPAddress(255, 255, 255, 0)); - // if DNSServer is started with "*" for domain name, it will reply with - // provided IP to all DNS request - dnsServer.start(DNS_PORT, "*", apIP); + // by default DNSServer is started serving any "*" domain name. It will reply + // AccessPoint's IP to all DNS request (this is requred for Captive Portal detection) + dnsServer.start(); + + // serve a simple root page + server.on("/", handleRoot); + + // serve portal page + server.on("/portal",[](){server.send(200, "text/html", responsePortal);}); + // all unknown pages are redirected to captive portal + server.onNotFound(handleNotFound); server.begin(); } void loop() { - dnsServer.processNextRequest(); - WiFiClient client = server.available(); // listen for incoming clients - - if (client) { - String currentLine = ""; - while (client.connected()) { - if (client.available()) { - char c = client.read(); - if (c == '\n') { - if (currentLine.length() == 0) { - client.println("HTTP/1.1 200 OK"); - client.println("Content-type:text/html"); - client.println(); - client.print(responseHTML); - break; - } else { - currentLine = ""; - } - } else if (c != '\r') { - currentLine += c; - } - } - } - client.stop(); - } + server.handleClient(); + delay(5); // give CPU some idle time } diff --git a/libraries/DNSServer/src/DNSServer.cpp b/libraries/DNSServer/src/DNSServer.cpp index aaa3d33344b..a8114733460 100644 --- a/libraries/DNSServer/src/DNSServer.cpp +++ b/libraries/DNSServer/src/DNSServer.cpp @@ -1,6 +1,8 @@ #include "DNSServer.h" #include #include +#include + // #define DEBUG_ESP_DNS #ifdef DEBUG_ESP_PORT @@ -9,45 +11,37 @@ #define DEBUG_OUTPUT Serial #endif -DNSServer::DNSServer() -{ - _ttl = htonl(DNS_DEFAULT_TTL); - _errorReplyCode = DNSReplyCode::NonExistentDomain; - _dnsHeader = (DNSHeader*) malloc( sizeof(DNSHeader) ) ; - _dnsQuestion = (DNSQuestion*) malloc( sizeof(DNSQuestion) ) ; - _buffer = NULL; - _currentPacketSize = 0; - _port = 0; -} +#define DNS_MIN_REQ_LEN 17 // minimal size for DNS request asking ROOT = DNS_HEADER_SIZE + 1 null byte for Name + 4 bytes type/class -DNSServer::~DNSServer() -{ - if (_dnsHeader) { - free(_dnsHeader); - _dnsHeader = NULL; - } - if (_dnsQuestion) { - free(_dnsQuestion); - _dnsQuestion = NULL; - } - if (_buffer) { - free(_buffer); - _buffer = NULL; - } +DNSServer::DNSServer() : _port(DNS_DEFAULT_PORT), _ttl(htonl(DNS_DEFAULT_TTL)), _errorReplyCode(DNSReplyCode::NonExistentDomain){} + +DNSServer::DNSServer(const String &domainName) : _port(DNS_DEFAULT_PORT), _ttl(htonl(DNS_DEFAULT_TTL)), _errorReplyCode(DNSReplyCode::NonExistentDomain), _domainName(domainName){}; + + +bool DNSServer::start(){ + if (_resolvedIP.operator uint32_t() == 0){ // no address is set, try to obtain AP interface's IP + if (WiFi.getMode() & WIFI_AP){ + _resolvedIP = WiFi.softAPIP(); + } else return false; // won't run if WiFi is not in AP mode + } + + _udp.close(); + _udp.onPacket([this](AsyncUDPPacket& pkt){ this->_handleUDP(pkt); }); + return _udp.listen(_port); } -bool DNSServer::start(const uint16_t &port, const String &domainName, - const IPAddress &resolvedIP) -{ +bool DNSServer::start(uint16_t port, const String &domainName, const IPAddress &resolvedIP){ _port = port; - _buffer = NULL; - _domainName = domainName; - _resolvedIP[0] = resolvedIP[0]; - _resolvedIP[1] = resolvedIP[1]; - _resolvedIP[2] = resolvedIP[2]; - _resolvedIP[3] = resolvedIP[3]; - downcaseAndRemoveWwwPrefix(_domainName); - return _udp.begin(_port) == 1; + if (domainName != "*"){ + _domainName = domainName; + downcaseAndRemoveWwwPrefix(_domainName); + } else + _domainName.clear(); + + _resolvedIP = resolvedIP; + _udp.close(); + _udp.onPacket([this](AsyncUDPPacket& pkt){ this->_handleUDP(pkt); }); + return _udp.listen(_port); } void DNSServer::setErrorReplyCode(const DNSReplyCode &replyCode) @@ -62,9 +56,7 @@ void DNSServer::setTTL(const uint32_t &ttl) void DNSServer::stop() { - _udp.stop(); - free(_buffer); - _buffer = NULL; + _udp.close(); } void DNSServer::downcaseAndRemoveWwwPrefix(String &domainName) @@ -73,151 +65,125 @@ void DNSServer::downcaseAndRemoveWwwPrefix(String &domainName) domainName.replace("www.", ""); } -void DNSServer::processNextRequest() +void DNSServer::_handleUDP(AsyncUDPPacket& pkt) { - _currentPacketSize = _udp.parsePacket(); - if (_currentPacketSize) - { - // Allocate buffer for the DNS query - if (_buffer != NULL) - free(_buffer); - _buffer = (unsigned char*)malloc(_currentPacketSize * sizeof(char)); - if (_buffer == NULL) - return; + if (pkt.length() < DNS_MIN_REQ_LEN) return; // truncated packet or not a DNS req + + // get DNS header (beginning of message) + DNSHeader dnsHeader; + DNSQuestion dnsQuestion; + memcpy( &dnsHeader, pkt.data(), DNS_HEADER_SIZE ); + if (dnsHeader.QR != DNS_QR_QUERY) return; // ignore non-query mesages - // Put the packet received in the buffer and get DNS header (beginning of message) - // and the question - _udp.read(_buffer, _currentPacketSize); - memcpy( _dnsHeader, _buffer, DNS_HEADER_SIZE ) ; - if ( requestIncludesOnlyOneQuestion() ) + if ( requestIncludesOnlyOneQuestion(dnsHeader) ) { +/* // The QName has a variable length, maximum 255 bytes and is comprised of multiple labels. // Each label contains a byte to describe its length and the label itself. The list of // labels terminates with a zero-valued byte. In "github.com", we have two labels "github" & "com" - // Iterate through the labels and copy them as they come into a single buffer (for simplicity's sake) - _dnsQuestion->QNameLength = 0 ; - while ( _buffer[ DNS_HEADER_SIZE + _dnsQuestion->QNameLength ] != 0 ) - { - memcpy( (void*) &_dnsQuestion->QName[_dnsQuestion->QNameLength], (void*) &_buffer[DNS_HEADER_SIZE + _dnsQuestion->QNameLength], _buffer[DNS_HEADER_SIZE + _dnsQuestion->QNameLength] + 1 ) ; - _dnsQuestion->QNameLength += _buffer[DNS_HEADER_SIZE + _dnsQuestion->QNameLength] + 1 ; - } - _dnsQuestion->QName[_dnsQuestion->QNameLength] = 0 ; - _dnsQuestion->QNameLength++ ; - - // Copy the QType and QClass - memcpy( &_dnsQuestion->QType, (void*) &_buffer[DNS_HEADER_SIZE + _dnsQuestion->QNameLength], sizeof(_dnsQuestion->QType) ) ; - memcpy( &_dnsQuestion->QClass, (void*) &_buffer[DNS_HEADER_SIZE + _dnsQuestion->QNameLength + sizeof(_dnsQuestion->QType)], sizeof(_dnsQuestion->QClass) ) ; +*/ + const char * enoflbls = strchr(reinterpret_cast(pkt.data()) + DNS_HEADER_SIZE, 0); // find end_of_label marker + ++enoflbls; // advance after null terminator + dnsQuestion.QName = pkt.data() + DNS_HEADER_SIZE; // we can reference labels from the request + dnsQuestion.QNameLength = enoflbls - (char*)pkt.data() - DNS_HEADER_SIZE; + /* + check if we aint going out of pkt bounds + proper dns req should have label terminator at least 4 bytes before end of packet + */ + if (dnsQuestion.QNameLength > pkt.length() - DNS_HEADER_SIZE - sizeof(dnsQuestion.QType) - sizeof(dnsQuestion.QClass)) return; // malformed packet + + // Copy the QType and QClass + memcpy( &dnsQuestion.QType, enoflbls, sizeof(dnsQuestion.QType) ); + memcpy( &dnsQuestion.QClass, enoflbls + sizeof(dnsQuestion.QType), sizeof(dnsQuestion.QClass) ); } - - if (_dnsHeader->QR == DNS_QR_QUERY && - _dnsHeader->OPCode == DNS_OPCODE_QUERY && - requestIncludesOnlyOneQuestion() && - (_domainName == "*" || getDomainNameWithoutWwwPrefix() == _domainName) + // will reply with IP only to "*" or if doman matches without www. subdomain + if (dnsHeader.OPCode == DNS_OPCODE_QUERY && + requestIncludesOnlyOneQuestion(dnsHeader) && + (_domainName.isEmpty() || + getDomainNameWithoutWwwPrefix(static_cast(dnsQuestion.QName), dnsQuestion.QNameLength) == _domainName) ) { - replyWithIP(); - } - else if (_dnsHeader->QR == DNS_QR_QUERY) - { - replyWithCustomCode(); + replyWithIP(pkt, dnsHeader, dnsQuestion); + return; } - free(_buffer); - _buffer = NULL; - } + // otherwise reply with custom code + replyWithCustomCode(pkt, dnsHeader); } -bool DNSServer::requestIncludesOnlyOneQuestion() +bool DNSServer::requestIncludesOnlyOneQuestion(DNSHeader& dnsHeader) { - return ntohs(_dnsHeader->QDCount) == 1 && - _dnsHeader->ANCount == 0 && - _dnsHeader->NSCount == 0 && - _dnsHeader->ARCount == 0; + return ntohs(dnsHeader.QDCount) == 1 && + dnsHeader.ANCount == 0 && + dnsHeader.NSCount == 0 && + dnsHeader.ARCount == 0; } -String DNSServer::getDomainNameWithoutWwwPrefix() +String DNSServer::getDomainNameWithoutWwwPrefix(const unsigned char* start, size_t len) { - // Error checking : if the buffer containing the DNS request is a null pointer, return an empty domain - String parsedDomainName = ""; - if (_buffer == NULL) - return parsedDomainName; - - // Set the start of the domain just after the header (12 bytes). If equal to null character, return an empty domain - unsigned char *start = _buffer + DNS_OFFSET_DOMAIN_NAME; - if (*start == 0) - { - return parsedDomainName; - } + String parsedDomainName(start, --len); // exclude trailing null byte from labels length, String constructor will add it anyway int pos = 0; - while(true) + while(posQR = DNS_QR_RESPONSE; - _dnsHeader->ANCount = _dnsHeader->QDCount; - _udp.write( (unsigned char*) _dnsHeader, DNS_HEADER_SIZE ) ; + dnsHeader.QR = DNS_QR_RESPONSE; + dnsHeader.ANCount = dnsHeader.QDCount; + rpl.write( (unsigned char*) &dnsHeader, DNS_HEADER_SIZE ) ; // Write the question - _udp.write(_dnsQuestion->QName, _dnsQuestion->QNameLength) ; - _udp.write( (unsigned char*) &_dnsQuestion->QType, 2 ) ; - _udp.write( (unsigned char*) &_dnsQuestion->QClass, 2 ) ; + rpl.write(dnsQuestion.QName, dnsQuestion.QNameLength) ; + rpl.write( (uint8_t*) &dnsQuestion.QType, 2 ) ; + rpl.write( (uint8_t*) &dnsQuestion.QClass, 2 ) ; // Write the answer // Use DNS name compression : instead of repeating the name in this RNAME occurence, // set the two MSB of the byte corresponding normally to the length to 1. The following // 14 bits must be used to specify the offset of the domain name in the message // (<255 here so the first byte has the 6 LSB at 0) - _udp.write((uint8_t) 0xC0); - _udp.write((uint8_t) DNS_OFFSET_DOMAIN_NAME); + rpl.write((uint8_t) 0xC0); + rpl.write((uint8_t) DNS_OFFSET_DOMAIN_NAME); // DNS type A : host address, DNS class IN for INternet, returning an IPv4 address uint16_t answerType = htons(DNS_TYPE_A), answerClass = htons(DNS_CLASS_IN), answerIPv4 = htons(DNS_RDLENGTH_IPV4) ; - _udp.write((unsigned char*) &answerType, 2 ); - _udp.write((unsigned char*) &answerClass, 2 ); - _udp.write((unsigned char*) &_ttl, 4); // DNS Time To Live - _udp.write((unsigned char*) &answerIPv4, 2 ); - _udp.write(_resolvedIP, sizeof(_resolvedIP)); // The IP address to return - _udp.endPacket(); + rpl.write((unsigned char*) &answerType, 2 ); + rpl.write((unsigned char*) &answerClass, 2 ); + rpl.write((unsigned char*) &_ttl, 4); // DNS Time To Live + rpl.write((unsigned char*) &answerIPv4, 2 ); + uint32_t ip = _resolvedIP; + rpl.write(reinterpret_cast(&ip), sizeof(uint32_t)); // The IPv4 address to return + + _udp.sendTo(rpl, req.remoteIP(), req.remotePort()); #ifdef DEBUG_ESP_DNS DEBUG_OUTPUT.printf("DNS responds: %s for %s\n", - IPAddress(_resolvedIP).toString().c_str(), getDomainNameWithoutWwwPrefix().c_str() ); + _resolvedIP.toString().c_str(), getDomainNameWithoutWwwPrefix(static_cast(dnsQuestion.QName), dnsQuestion.QNameLength).c_str() ); #endif } -void DNSServer::replyWithCustomCode() +void DNSServer::replyWithCustomCode(AsyncUDPPacket& req, DNSHeader& dnsHeader) { - _dnsHeader->QR = DNS_QR_RESPONSE; - _dnsHeader->RCode = (unsigned char)_errorReplyCode; - _dnsHeader->QDCount = 0; + dnsHeader.QR = DNS_QR_RESPONSE; + dnsHeader.RCode = static_cast(_errorReplyCode); + dnsHeader.QDCount = 0; - _udp.beginPacket(_udp.remoteIP(), _udp.remotePort()); - _udp.write((unsigned char*)_dnsHeader, sizeof(DNSHeader)); - _udp.endPacket(); + AsyncUDPMessage rpl(sizeof(DNSHeader)); + rpl.write(reinterpret_cast(&dnsHeader), sizeof(DNSHeader)); + _udp.sendTo(rpl, req.remoteIP(), req.remotePort()); } diff --git a/libraries/DNSServer/src/DNSServer.h b/libraries/DNSServer/src/DNSServer.h index 1250f5ce960..860507ac9ec 100644 --- a/libraries/DNSServer/src/DNSServer.h +++ b/libraries/DNSServer/src/DNSServer.h @@ -1,15 +1,15 @@ -#ifndef DNSServer_h -#define DNSServer_h -#include +#pragma once +#include #define DNS_QR_QUERY 0 #define DNS_QR_RESPONSE 1 #define DNS_OPCODE_QUERY 0 #define DNS_DEFAULT_TTL 60 // Default Time To Live : time interval in seconds that the resource record should be cached before being discarded -#define DNS_OFFSET_DOMAIN_NAME 12 // Offset in bytes to reach the domain name in the DNS message #define DNS_HEADER_SIZE 12 +#define DNS_OFFSET_DOMAIN_NAME DNS_HEADER_SIZE // Offset in bytes to reach the domain name labels in the DNS message +#define DNS_DEFAULT_PORT 53 -enum class DNSReplyCode +enum class DNSReplyCode:uint16_t { NoError = 0, FormError = 1, @@ -59,14 +59,14 @@ struct DNSHeader uint16_t Flags; }; uint16_t QDCount; // number of question entries - uint16_t ANCount; // number of answer entries + uint16_t ANCount; // number of ANswer entries uint16_t NSCount; // number of authority entries - uint16_t ARCount; // number of resource entries + uint16_t ARCount; // number of Additional Resource entries }; struct DNSQuestion { - uint8_t QName[256] ; //need 1 Byte for zero termination! + const uint8_t* QName; uint16_t QNameLength ; uint16_t QType ; uint16_t QClass ; @@ -75,36 +75,116 @@ struct DNSQuestion class DNSServer { public: + /** + * @brief Construct a new DNSServer object + * by default server is configured to run in "Captive-portal" mode + * it must be started with start() call to establish a listening socket + * + */ DNSServer(); - ~DNSServer(); - void processNextRequest(); + + /** + * @brief Construct a new DNSServer object + * builds DNS server with default parameters + * @param domainName - domain name to serve + */ + DNSServer(const String &domainName); + ~DNSServer(){}; // default d-tor + + // Copy semantics not implemented (won't run on same UDP port anyway) + DNSServer(const DNSServer&) = delete; + DNSServer& operator=(const DNSServer&) = delete; + + + /** + * @brief stub, left for compatibility with an old version + * does nothing actually + * + */ + void processNextRequest(){}; + + /** + * @brief Set the Error Reply Code for all req's not matching predifined domain + * + * @param replyCode + */ void setErrorReplyCode(const DNSReplyCode &replyCode); + + /** + * @brief set TTL for successfull replies + * + * @param ttl in seconds + */ void setTTL(const uint32_t &ttl); - // Returns true if successful, false if there are no sockets available - bool start(const uint16_t &port, + /** + * @brief (re)Starts a server with current configuration or with default parameters + * if it's the first call. + * Defaults are: + * port: 53 + * domainName: any + * ip: WiFi AP's IP address + * + * @return true on success + * @return false if IP or socket error + */ + bool start(); + + /** + * @brief (re)Starts a server with provided configuration + * + * @return true on success + * @return false if IP or socket error + */ + bool start(uint16_t port, const String &domainName, const IPAddress &resolvedIP); - // stops the DNS server + + /** + * @brief stops the server and close UDP socket + * + */ void stop(); + /** + * @brief returns true if DNS server runs in captive-portal mode + * i.e. all requests are served with AP's ip address + * + * @return true if catch-all mode active + * @return false otherwise + */ + inline bool isCaptive() const { return _domainName.isEmpty(); }; + + /** + * @brief returns 'true' if server is up and UDP socket is listening for UDP req's + * + * @return true if server is up + * @return false otherwise + */ + inline bool isUp() { return _udp.connected(); }; + private: - WiFiUDP _udp; + AsyncUDP _udp; uint16_t _port; - String _domainName; - unsigned char _resolvedIP[4]; - int _currentPacketSize; - unsigned char* _buffer; - DNSHeader* _dnsHeader; uint32_t _ttl; DNSReplyCode _errorReplyCode; - DNSQuestion* _dnsQuestion ; + String _domainName; + IPAddress _resolvedIP; void downcaseAndRemoveWwwPrefix(String &domainName); - String getDomainNameWithoutWwwPrefix(); - bool requestIncludesOnlyOneQuestion(); - void replyWithIP(); - void replyWithCustomCode(); + + /** + * @brief Get the Domain Name Without Www Prefix object + * scan labels in DNS packet and build a string of a domain name + * truncate any www. label if found + * @param start a pointer to the start of labels records in DNS packet + * @param len labels length + * @return String + */ + String getDomainNameWithoutWwwPrefix(const unsigned char* start, size_t len); + inline bool requestIncludesOnlyOneQuestion(DNSHeader& dnsHeader); + void replyWithIP(AsyncUDPPacket& req, DNSHeader& dnsHeader, DNSQuestion& dnsQuestion); + inline void replyWithCustomCode(AsyncUDPPacket& req, DNSHeader& dnsHeader); + void _handleUDP(AsyncUDPPacket& pkt); }; -#endif