Skip to content

Commit 22d9c46

Browse files
added downloadAndDecompress function
1 parent de12663 commit 22d9c46

File tree

2 files changed

+141
-30
lines changed

2 files changed

+141
-30
lines changed

src/Arduino_Portenta_OTA.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,8 @@ class Arduino_Portenta_OTA
103103
*/
104104
int download(const char * url, bool const is_https, MbedSocketClass * socket = static_cast<MbedSocketClass*>(&WiFi));
105105
int decompress();
106+
int downloadAndDecompress(const char * url, bool const is_https, MbedSocketClass * socket = static_cast<MbedSocketClass*>(&WiFi));
107+
106108
void setFeedWatchdogFunc(ArduinoPortentaOtaWatchdogResetFuncPointer func);
107109
void feedWatchdog();
108110

src/decompress/utility.cpp

Lines changed: 139 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -89,11 +89,149 @@ uint32_t crc_update(uint32_t crc, const void * data, size_t data_len)
8989
MAIN
9090
**************************************************************************************/
9191

92+
union HeaderVersion
93+
{
94+
struct __attribute__((packed))
95+
{
96+
uint32_t header_version : 6;
97+
uint32_t compression : 1;
98+
uint32_t signature : 1;
99+
uint32_t spare : 4;
100+
uint32_t payload_target : 4;
101+
uint32_t payload_major : 8;
102+
uint32_t payload_minor : 8;
103+
uint32_t payload_patch : 8;
104+
uint32_t payload_build_num : 24;
105+
} field;
106+
uint8_t buf[sizeof(field)];
107+
static_assert(sizeof(buf) == 8, "Error: sizeof(HEADER.VERSION) != 8");
108+
};
109+
110+
union OTAHeader
111+
{
112+
struct __attribute__((packed))
113+
{
114+
uint32_t len;
115+
uint32_t crc32;
116+
uint32_t magic_number;
117+
HeaderVersion hdr_version;
118+
} header;
119+
uint8_t buf[sizeof(header)];
120+
static_assert(sizeof(buf) == 20, "Error: sizeof(HEADER) != 20");
121+
};
122+
92123
int Arduino_Portenta_OTA::download(const char * url, bool const is_https, MbedSocketClass * socket)
93124
{
94125
return socket->download((char *)url, UPDATE_FILE_NAME_LZSS, is_https);
95126
}
96127

128+
int Arduino_Portenta_OTA::downloadAndDecompress(const char * url, bool const is_https, MbedSocketClass * socket) {
129+
int res=0;
130+
131+
FILE* decompressed = fopen(UPDATE_FILE_NAME, "wb");
132+
OTAHeader ota_header;
133+
134+
LZSSDecoder decoder([&decompressed](const uint8_t c) {
135+
fwrite(&c, 1, 1, decompressed);
136+
});
137+
138+
enum OTA_DOWNLOAD_STATE: uint8_t {
139+
OTA_DOWNLOAD_HEADER=0,
140+
OTA_DOWNLOAD_FILE,
141+
OTA_DOWNLOAD_ERR
142+
};
143+
144+
// since mbed::Callback requires a function to not exceed a certain size, we group the following parameters in a struct
145+
struct {
146+
uint32_t crc32 = 0xFFFFFFFF;
147+
uint32_t header_copied_bytes = 0;
148+
OTA_DOWNLOAD_STATE state=OTA_DOWNLOAD_HEADER;
149+
} ota_progress;
150+
151+
int bytes = socket->download(url, is_https, [&decoder, &ota_header, &ota_progress](const char* buffer, uint32_t size) {
152+
for(char* cursor=(char*)buffer; cursor<buffer+size; ) {
153+
switch(ota_progress.state) {
154+
case OTA_DOWNLOAD_HEADER: {
155+
// read to ota_header.buf
156+
// the header could be split into two arrivals, we must handle that
157+
uint32_t copied = size < sizeof(ota_header.buf) ? size : sizeof(ota_header.buf);
158+
memcpy(ota_header.buf, buffer, copied);
159+
cursor += copied;
160+
ota_progress.header_copied_bytes += copied;
161+
162+
// when finished go to next state
163+
if(sizeof(ota_header.buf) == ota_progress.header_copied_bytes) {
164+
ota_progress.state = OTA_DOWNLOAD_FILE;
165+
166+
ota_progress.crc32 = crc_update(
167+
ota_progress.crc32,
168+
&(ota_header.header.magic_number),
169+
sizeof(ota_header) - offsetof(OTAHeader, header.magic_number)
170+
);
171+
172+
}
173+
break;
174+
}
175+
case OTA_DOWNLOAD_FILE:
176+
// continue to download the payload, decompressing it and calculate crc
177+
decoder.decompress((uint8_t*)cursor, size - (cursor-buffer));
178+
ota_progress.crc32 = crc_update(
179+
ota_progress.crc32,
180+
cursor,
181+
size - (cursor-buffer)
182+
);
183+
184+
cursor += size - (cursor-buffer);
185+
break;
186+
default:
187+
ota_progress.state = OTA_DOWNLOAD_ERR;
188+
}
189+
}
190+
});
191+
192+
// if download fails it return a negative error code
193+
if(bytes <= 0) {
194+
res = bytes;
195+
goto exit;
196+
}
197+
198+
// if state is download finished and completed correctly the state should be OTA_DOWNLOAD_FILE
199+
if(ota_progress.state != OTA_DOWNLOAD_FILE) {
200+
res = static_cast<int>(Error::OtaDownload);
201+
goto exit;
202+
}
203+
204+
if(ota_header.header.len == (bytes-sizeof(ota_header.buf))) {
205+
res = static_cast<int>(Error::OtaHeaderLength);
206+
goto exit;
207+
}
208+
209+
// verify magic number: it may be done in the download function and stop the download immediately
210+
if(ota_header.header.magic_number != ARDUINO_PORTENTA_OTA_MAGIC) {
211+
res = static_cast<int>(Error::OtaHeaterMagicNumber);
212+
goto exit;
213+
}
214+
215+
// finalize CRC and verify it
216+
ota_progress.crc32 ^= 0xFFFFFFFF;
217+
if(ota_header.header.crc32 != ota_progress.crc32) {
218+
res = static_cast<int>(Error::OtaHeaderCrc);
219+
goto exit;
220+
}
221+
222+
res = ftell(decompressed);
223+
224+
exit:
225+
fclose(decompressed);
226+
227+
if(res < 0) {
228+
remove(UPDATE_FILE_NAME);
229+
}
230+
231+
return res;
232+
}
233+
234+
97235
int Arduino_Portenta_OTA::decompress()
98236
{
99237
struct stat stat_buf;
@@ -103,36 +241,7 @@ int Arduino_Portenta_OTA::decompress()
103241
/* For UPDATE.BIN.LZSS - LZSS compressed binary files. */
104242
FILE* update_file = fopen(UPDATE_FILE_NAME_LZSS, "rb");
105243

106-
union HeaderVersion
107-
{
108-
struct __attribute__((packed))
109-
{
110-
uint32_t header_version : 6;
111-
uint32_t compression : 1;
112-
uint32_t signature : 1;
113-
uint32_t spare : 4;
114-
uint32_t payload_target : 4;
115-
uint32_t payload_major : 8;
116-
uint32_t payload_minor : 8;
117-
uint32_t payload_patch : 8;
118-
uint32_t payload_build_num : 24;
119-
} field;
120-
uint8_t buf[sizeof(field)];
121-
static_assert(sizeof(buf) == 8, "Error: sizeof(HEADER.VERSION) != 8");
122-
};
123-
124-
union
125-
{
126-
struct __attribute__((packed))
127-
{
128-
uint32_t len;
129-
uint32_t crc32;
130-
uint32_t magic_number;
131-
HeaderVersion hdr_version;
132-
} header;
133-
uint8_t buf[sizeof(header)];
134-
static_assert(sizeof(buf) == 20, "Error: sizeof(HEADER) != 20");
135-
} ota_header;
244+
OTAHeader ota_header;
136245
uint32_t crc32, bytes_read;
137246
uint8_t crc_buf[128];
138247

0 commit comments

Comments
 (0)