Skip to content

Commit 035c4d6

Browse files
[3.14] gh-132983: Minor fixes and clean up for the _zstd module (GH-134930) (GH-134998)
(cherry picked from commit b595237) Co-authored-by: Serhiy Storchaka <storchaka@gmail.com>
1 parent 777fd49 commit 035c4d6

File tree

6 files changed

+166
-160
lines changed

6 files changed

+166
-160
lines changed

Lib/test/test_zstd.py

Lines changed: 77 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1139,27 +1139,41 @@ def test_invalid_dict(self):
11391139
ZstdDecompressor(zd)
11401140

11411141
# wrong type
1142-
with self.assertRaisesRegex(TypeError, r'should be ZstdDict object'):
1143-
ZstdCompressor(zstd_dict=(zd, b'123'))
1144-
with self.assertRaisesRegex(TypeError, r'should be ZstdDict object'):
1142+
with self.assertRaisesRegex(TypeError, r'should be a ZstdDict object'):
1143+
ZstdCompressor(zstd_dict=[zd, 1])
1144+
with self.assertRaisesRegex(TypeError, r'should be a ZstdDict object'):
1145+
ZstdCompressor(zstd_dict=(zd, 1.0))
1146+
with self.assertRaisesRegex(TypeError, r'should be a ZstdDict object'):
1147+
ZstdCompressor(zstd_dict=(zd,))
1148+
with self.assertRaisesRegex(TypeError, r'should be a ZstdDict object'):
11451149
ZstdCompressor(zstd_dict=(zd, 1, 2))
1146-
with self.assertRaisesRegex(TypeError, r'should be ZstdDict object'):
1150+
with self.assertRaisesRegex(TypeError, r'should be a ZstdDict object'):
11471151
ZstdCompressor(zstd_dict=(zd, -1))
1148-
with self.assertRaisesRegex(TypeError, r'should be ZstdDict object'):
1152+
with self.assertRaisesRegex(TypeError, r'should be a ZstdDict object'):
11491153
ZstdCompressor(zstd_dict=(zd, 3))
1150-
1151-
with self.assertRaisesRegex(TypeError, r'should be ZstdDict object'):
1152-
ZstdDecompressor(zstd_dict=(zd, b'123'))
1153-
with self.assertRaisesRegex(TypeError, r'should be ZstdDict object'):
1154+
with self.assertRaises(OverflowError):
1155+
ZstdCompressor(zstd_dict=(zd, 2**1000))
1156+
with self.assertRaises(OverflowError):
1157+
ZstdCompressor(zstd_dict=(zd, -2**1000))
1158+
1159+
with self.assertRaisesRegex(TypeError, r'should be a ZstdDict object'):
1160+
ZstdDecompressor(zstd_dict=[zd, 1])
1161+
with self.assertRaisesRegex(TypeError, r'should be a ZstdDict object'):
1162+
ZstdDecompressor(zstd_dict=(zd, 1.0))
1163+
with self.assertRaisesRegex(TypeError, r'should be a ZstdDict object'):
1164+
ZstdDecompressor((zd,))
1165+
with self.assertRaisesRegex(TypeError, r'should be a ZstdDict object'):
11541166
ZstdDecompressor((zd, 1, 2))
1155-
with self.assertRaisesRegex(TypeError, r'should be ZstdDict object'):
1167+
with self.assertRaisesRegex(TypeError, r'should be a ZstdDict object'):
11561168
ZstdDecompressor((zd, -1))
1157-
with self.assertRaisesRegex(TypeError, r'should be ZstdDict object'):
1169+
with self.assertRaisesRegex(TypeError, r'should be a ZstdDict object'):
11581170
ZstdDecompressor((zd, 3))
1171+
with self.assertRaises(OverflowError):
1172+
ZstdDecompressor((zd, 2**1000))
1173+
with self.assertRaises(OverflowError):
1174+
ZstdDecompressor((zd, -2**1000))
11591175

11601176
def test_train_dict(self):
1161-
1162-
11631177
TRAINED_DICT = train_dict(SAMPLES, DICT_SIZE1)
11641178
ZstdDict(TRAINED_DICT.dict_content, is_raw=False)
11651179

@@ -1240,18 +1254,37 @@ def test_train_dict_c(self):
12401254
# argument wrong type
12411255
with self.assertRaises(TypeError):
12421256
_zstd.train_dict({}, (), 100)
1257+
with self.assertRaises(TypeError):
1258+
_zstd.train_dict(bytearray(), (), 100)
12431259
with self.assertRaises(TypeError):
12441260
_zstd.train_dict(b'', 99, 100)
1261+
with self.assertRaises(TypeError):
1262+
_zstd.train_dict(b'', [], 100)
12451263
with self.assertRaises(TypeError):
12461264
_zstd.train_dict(b'', (), 100.1)
1265+
with self.assertRaises(TypeError):
1266+
_zstd.train_dict(b'', (99.1,), 100)
1267+
with self.assertRaises(ValueError):
1268+
_zstd.train_dict(b'abc', (4, -1), 100)
1269+
with self.assertRaises(ValueError):
1270+
_zstd.train_dict(b'abc', (2,), 100)
1271+
with self.assertRaises(ValueError):
1272+
_zstd.train_dict(b'', (99,), 100)
12471273

12481274
# size > size_t
12491275
with self.assertRaises(ValueError):
1250-
_zstd.train_dict(b'', (2**64+1,), 100)
1276+
_zstd.train_dict(b'', (2**1000,), 100)
1277+
with self.assertRaises(ValueError):
1278+
_zstd.train_dict(b'', (-2**1000,), 100)
12511279

12521280
# dict_size <= 0
12531281
with self.assertRaises(ValueError):
12541282
_zstd.train_dict(b'', (), 0)
1283+
with self.assertRaises(ValueError):
1284+
_zstd.train_dict(b'', (), -1)
1285+
1286+
with self.assertRaises(ZstdError):
1287+
_zstd.train_dict(b'', (), 1)
12551288

12561289
def test_finalize_dict_c(self):
12571290
with self.assertRaises(TypeError):
@@ -1260,22 +1293,51 @@ def test_finalize_dict_c(self):
12601293
# argument wrong type
12611294
with self.assertRaises(TypeError):
12621295
_zstd.finalize_dict({}, b'', (), 100, 5)
1296+
with self.assertRaises(TypeError):
1297+
_zstd.finalize_dict(bytearray(TRAINED_DICT.dict_content), b'', (), 100, 5)
12631298
with self.assertRaises(TypeError):
12641299
_zstd.finalize_dict(TRAINED_DICT.dict_content, {}, (), 100, 5)
1300+
with self.assertRaises(TypeError):
1301+
_zstd.finalize_dict(TRAINED_DICT.dict_content, bytearray(), (), 100, 5)
12651302
with self.assertRaises(TypeError):
12661303
_zstd.finalize_dict(TRAINED_DICT.dict_content, b'', 99, 100, 5)
1304+
with self.assertRaises(TypeError):
1305+
_zstd.finalize_dict(TRAINED_DICT.dict_content, b'', [], 100, 5)
12671306
with self.assertRaises(TypeError):
12681307
_zstd.finalize_dict(TRAINED_DICT.dict_content, b'', (), 100.1, 5)
12691308
with self.assertRaises(TypeError):
12701309
_zstd.finalize_dict(TRAINED_DICT.dict_content, b'', (), 100, 5.1)
12711310

1311+
with self.assertRaises(ValueError):
1312+
_zstd.finalize_dict(TRAINED_DICT.dict_content, b'abc', (4, -1), 100, 5)
1313+
with self.assertRaises(ValueError):
1314+
_zstd.finalize_dict(TRAINED_DICT.dict_content, b'abc', (2,), 100, 5)
1315+
with self.assertRaises(ValueError):
1316+
_zstd.finalize_dict(TRAINED_DICT.dict_content, b'', (99,), 100, 5)
1317+
12721318
# size > size_t
12731319
with self.assertRaises(ValueError):
1274-
_zstd.finalize_dict(TRAINED_DICT.dict_content, b'', (2**64+1,), 100, 5)
1320+
_zstd.finalize_dict(TRAINED_DICT.dict_content, b'', (2**1000,), 100, 5)
1321+
with self.assertRaises(ValueError):
1322+
_zstd.finalize_dict(TRAINED_DICT.dict_content, b'', (-2**1000,), 100, 5)
12751323

12761324
# dict_size <= 0
12771325
with self.assertRaises(ValueError):
12781326
_zstd.finalize_dict(TRAINED_DICT.dict_content, b'', (), 0, 5)
1327+
with self.assertRaises(ValueError):
1328+
_zstd.finalize_dict(TRAINED_DICT.dict_content, b'', (), -1, 5)
1329+
with self.assertRaises(OverflowError):
1330+
_zstd.finalize_dict(TRAINED_DICT.dict_content, b'', (), 2**1000, 5)
1331+
with self.assertRaises(OverflowError):
1332+
_zstd.finalize_dict(TRAINED_DICT.dict_content, b'', (), -2**1000, 5)
1333+
1334+
with self.assertRaises(OverflowError):
1335+
_zstd.finalize_dict(TRAINED_DICT.dict_content, b'', (), 100, 2**1000)
1336+
with self.assertRaises(OverflowError):
1337+
_zstd.finalize_dict(TRAINED_DICT.dict_content, b'', (), 100, -2**1000)
1338+
1339+
with self.assertRaises(ZstdError):
1340+
_zstd.finalize_dict(TRAINED_DICT.dict_content, b'', (), 100, 5)
12791341

12801342
def test_train_buffer_protocol_samples(self):
12811343
def _nbytes(dat):

Modules/_zstd/_zstdmodule.c

Lines changed: 60 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
#include "Python.h"
88

99
#include "_zstdmodule.h"
10-
#include "zstddict.h"
1110

1211
#include <zstd.h> // ZSTD_*()
1312
#include <zdict.h> // ZDICT_*()
@@ -20,14 +19,52 @@ module _zstd
2019
#include "clinic/_zstdmodule.c.h"
2120

2221

22+
ZstdDict *
23+
_Py_parse_zstd_dict(const _zstd_state *state, PyObject *dict, int *ptype)
24+
{
25+
if (state == NULL) {
26+
return NULL;
27+
}
28+
29+
/* Check ZstdDict */
30+
if (PyObject_TypeCheck(dict, state->ZstdDict_type)) {
31+
return (ZstdDict*)dict;
32+
}
33+
34+
/* Check (ZstdDict, type) */
35+
if (PyTuple_CheckExact(dict) && PyTuple_GET_SIZE(dict) == 2
36+
&& PyObject_TypeCheck(PyTuple_GET_ITEM(dict, 0), state->ZstdDict_type)
37+
&& PyLong_Check(PyTuple_GET_ITEM(dict, 1)))
38+
{
39+
int type = PyLong_AsInt(PyTuple_GET_ITEM(dict, 1));
40+
if (type == -1 && PyErr_Occurred()) {
41+
return NULL;
42+
}
43+
if (type == DICT_TYPE_DIGESTED
44+
|| type == DICT_TYPE_UNDIGESTED
45+
|| type == DICT_TYPE_PREFIX)
46+
{
47+
*ptype = type;
48+
return (ZstdDict*)PyTuple_GET_ITEM(dict, 0);
49+
}
50+
}
51+
52+
/* Wrong type */
53+
PyErr_SetString(PyExc_TypeError,
54+
"zstd_dict argument should be a ZstdDict object.");
55+
return NULL;
56+
}
57+
2358
/* Format error message and set ZstdError. */
2459
void
25-
set_zstd_error(const _zstd_state* const state,
26-
error_type type, size_t zstd_ret)
60+
set_zstd_error(const _zstd_state *state, error_type type, size_t zstd_ret)
2761
{
28-
char *msg;
62+
const char *msg;
2963
assert(ZSTD_isError(zstd_ret));
3064

65+
if (state == NULL) {
66+
return;
67+
}
3168
switch (type) {
3269
case ERR_DECOMPRESS:
3370
msg = "Unable to decompress Zstandard data: %s";
@@ -174,7 +211,7 @@ calculate_samples_stats(PyBytesObject *samples_bytes, PyObject *samples_sizes,
174211
Py_ssize_t sizes_sum;
175212
Py_ssize_t i;
176213

177-
chunks_number = Py_SIZE(samples_sizes);
214+
chunks_number = PyTuple_GET_SIZE(samples_sizes);
178215
if ((size_t) chunks_number > UINT32_MAX) {
179216
PyErr_Format(PyExc_ValueError,
180217
"The number of samples should be <= %u.", UINT32_MAX);
@@ -188,20 +225,24 @@ calculate_samples_stats(PyBytesObject *samples_bytes, PyObject *samples_sizes,
188225
return -1;
189226
}
190227

191-
sizes_sum = 0;
228+
sizes_sum = PyBytes_GET_SIZE(samples_bytes);
192229
for (i = 0; i < chunks_number; i++) {
193-
PyObject *size = PyTuple_GetItem(samples_sizes, i);
194-
(*chunk_sizes)[i] = PyLong_AsSize_t(size);
195-
if ((*chunk_sizes)[i] == (size_t)-1 && PyErr_Occurred()) {
196-
PyErr_Format(PyExc_ValueError,
197-
"Items in samples_sizes should be an int "
198-
"object, with a value between 0 and %u.", SIZE_MAX);
230+
size_t size = PyLong_AsSize_t(PyTuple_GET_ITEM(samples_sizes, i));
231+
(*chunk_sizes)[i] = size;
232+
if (size == (size_t)-1 && PyErr_Occurred()) {
233+
if (PyErr_ExceptionMatches(PyExc_OverflowError)) {
234+
goto sum_error;
235+
}
199236
return -1;
200237
}
201-
sizes_sum += (*chunk_sizes)[i];
238+
if ((size_t)sizes_sum < size) {
239+
goto sum_error;
240+
}
241+
sizes_sum -= size;
202242
}
203243

204-
if (sizes_sum != Py_SIZE(samples_bytes)) {
244+
if (sizes_sum != 0) {
245+
sum_error:
205246
PyErr_SetString(PyExc_ValueError,
206247
"The samples size tuple doesn't match the "
207248
"concatenation's size.");
@@ -257,7 +298,7 @@ _zstd_train_dict_impl(PyObject *module, PyBytesObject *samples_bytes,
257298

258299
/* Train the dictionary */
259300
char *dst_dict_buffer = PyBytes_AS_STRING(dst_dict_bytes);
260-
char *samples_buffer = PyBytes_AS_STRING(samples_bytes);
301+
const char *samples_buffer = PyBytes_AS_STRING(samples_bytes);
261302
Py_BEGIN_ALLOW_THREADS
262303
zstd_ret = ZDICT_trainFromBuffer(dst_dict_buffer, dict_size,
263304
samples_buffer,
@@ -507,17 +548,10 @@ _zstd_set_parameter_types_impl(PyObject *module, PyObject *c_parameter_type,
507548
{
508549
_zstd_state* mod_state = get_zstd_state(module);
509550

510-
if (!PyType_Check(c_parameter_type) || !PyType_Check(d_parameter_type)) {
511-
PyErr_SetString(PyExc_ValueError,
512-
"The two arguments should be CompressionParameter and "
513-
"DecompressionParameter types.");
514-
return NULL;
515-
}
516-
517-
Py_XSETREF(
518-
mod_state->CParameter_type, (PyTypeObject*)Py_NewRef(c_parameter_type));
519-
Py_XSETREF(
520-
mod_state->DParameter_type, (PyTypeObject*)Py_NewRef(d_parameter_type));
551+
Py_INCREF(c_parameter_type);
552+
Py_XSETREF(mod_state->CParameter_type, (PyTypeObject*)c_parameter_type);
553+
Py_INCREF(d_parameter_type);
554+
Py_XSETREF(mod_state->DParameter_type, (PyTypeObject*)d_parameter_type);
521555

522556
Py_RETURN_NONE;
523557
}
@@ -580,7 +614,6 @@ do { \
580614
return -1;
581615
}
582616
if (PyModule_AddType(m, (PyTypeObject *)mod_state->ZstdError) < 0) {
583-
Py_DECREF(mod_state->ZstdError);
584617
return -1;
585618
}
586619

Modules/_zstd/_zstdmodule.h

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
#ifndef ZSTD_MODULE_H
66
#define ZSTD_MODULE_H
77

8+
#include "zstddict.h"
9+
810
/* Type specs */
911
extern PyType_Spec zstd_dict_type_spec;
1012
extern PyType_Spec zstd_compressor_type_spec;
@@ -43,10 +45,14 @@ typedef enum {
4345
DICT_TYPE_PREFIX = 2
4446
} dictionary_type;
4547

48+
extern ZstdDict *
49+
_Py_parse_zstd_dict(const _zstd_state *state,
50+
PyObject *dict, int *type);
51+
4652
/* Format error message and set ZstdError. */
4753
extern void
48-
set_zstd_error(const _zstd_state* const state,
49-
const error_type type, size_t zstd_ret);
54+
set_zstd_error(const _zstd_state *state,
55+
error_type type, size_t zstd_ret);
5056

5157
extern void
5258
set_parameter_error(int is_compress, int key_v, int value_v);

0 commit comments

Comments
 (0)