@@ -49,7 +49,8 @@ CREATE OR REPLACE FUNCTION aws_s3.table_import_from_s3 (
49
49
secret_key text default null ,
50
50
session_token text default null ,
51
51
endpoint_url text default null ,
52
- content_encoding text default null
52
+ read_timeout integer default 60 ,
53
+ tempfile_dir text default ' /var/lib/postgresql/data/'
53
54
) RETURNS int
54
55
LANGUAGE plpython3u
55
56
AS $$
@@ -86,33 +87,51 @@ AS $$
86
87
s3 = boto3 .resource (
87
88
' s3' ,
88
89
region_name= region,
90
+ config= boto3 .session .Config(read_timeout= read_timeout),
89
91
** aws_settings
90
92
)
91
93
92
- obj = s3 .Object (bucket, file_path)
93
- response = obj .get ()
94
- content_encoding = content_encoding or response .get (' ContentEncoding' )
95
- user_content_encoding = response .get (' x-amz-meta-content-encoding' )
96
- body = response[' Body' ]
97
-
98
- with tempfile .NamedTemporaryFile () as fd:
99
- if (content_encoding and content_encoding .lower () == ' gzip' ) or (user_content_encoding and user_content_encoding .lower () == ' gzip' ):
100
- with gzip .GzipFile (fileobj= body) as gzipfile:
101
- while fd .write (gzipfile .read (204800 )):
102
- pass
103
- else:
104
- while fd .write (body .read (204800 )):
105
- pass
106
- fd .flush ()
107
- formatted_column_list = " ({column_list})" .format(column_list= column_list) if column_list else ' '
108
- res = plpy .execute (" COPY {table_name} {formatted_column_list} FROM {filename} {options};" .format(
109
- table_name= table_name,
110
- filename= plpy .quote_literal (fd .name ),
111
- formatted_column_list= formatted_column_list,
112
- options= options
113
- )
114
- )
115
- return res .nrows ()
94
+ formatted_column_list = " ({column_list})" .format(column_list= column_list) if column_list else ' '
95
+ num_rows = 0
96
+
97
+ for file_path_item in file_path .split (" ," ):
98
+ file_path_item = file_path_item .strip ()
99
+ if not file_path_item:
100
+ continue
101
+
102
+ s3_objects = []
103
+ if file_path_item .endswith (" /" ): # Directory
104
+ bucket_objects = s3 .Bucket (bucket).objects .filter (Prefix= file_path_item)
105
+ s3_objects = [bucket_object for bucket_object in bucket_objects]
106
+ else: # File
107
+ s3_object = s3 .Object (bucket, file_path_item)
108
+ s3_objects = [s3_object]
109
+
110
+ for s3_object in s3_objects:
111
+ response = s3_object .get ()
112
+ content_encoding = response .get (' ContentEncoding' )
113
+ body = response[' Body' ]
114
+ user_content_encoding = response .get (' x-amz-meta-content-encoding' )
115
+
116
+ with tempfile .NamedTemporaryFile (dir= tempfile_dir) as fd:
117
+ if (content_encoding and content_encoding .lower () == ' gzip' ) or (user_content_encoding and user_content_encoding .lower () == ' gzip' ):
118
+ with gzip .GzipFile (fileobj= body) as gzipfile:
119
+ while fd .write (gzipfile .read (204800 )):
120
+ pass
121
+ else:
122
+ while fd .write (body .read (204800 )):
123
+ pass
124
+ fd .flush ()
125
+
126
+ res = plpy .execute (" COPY {table_name} {formatted_column_list} FROM {filename} {options};" .format(
127
+ table_name= table_name,
128
+ filename= plpy .quote_literal (fd .name ),
129
+ formatted_column_list= formatted_column_list,
130
+ options= options
131
+ )
132
+ )
133
+ num_rows + = res .nrows ()
134
+ return num_rows
116
135
$$;
117
136
118
137
--
@@ -126,14 +145,15 @@ CREATE OR REPLACE FUNCTION aws_s3.table_import_from_s3(
126
145
s3_info aws_commons ._s3_uri_1 ,
127
146
credentials aws_commons ._aws_credentials_1 ,
128
147
endpoint_url text default null ,
129
- content_encoding text default null
148
+ read_timeout integer default 60 ,
149
+ tempfile_dir text default ' /var/lib/postgresql/data/'
130
150
) RETURNS INT
131
151
LANGUAGE plpython3u
132
152
AS $$
133
153
134
154
plan = plpy .prepare (
135
- ' SELECT aws_s3.table_import_from_s3($1, $2, $3, $4, $5, $6, $7, $8, $9, $10) AS num_rows' ,
136
- [' TEXT' , ' TEXT' , ' TEXT' , ' TEXT' , ' TEXT' , ' TEXT' , ' TEXT' , ' TEXT' , ' TEXT' , ' TEXT' , ' TEXT' ]
155
+ ' SELECT aws_s3.table_import_from_s3($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12 ) AS num_rows' ,
156
+ [' TEXT' , ' TEXT' , ' TEXT' , ' TEXT' , ' TEXT' , ' TEXT' , ' TEXT' , ' TEXT' , ' TEXT' , ' TEXT' , ' INTEGER ' , ' TEXT' ]
137
157
)
138
158
return plan .execute (
139
159
[
@@ -146,8 +166,8 @@ AS $$
146
166
credentials[' access_key' ],
147
167
credentials[' secret_key' ],
148
168
credentials[' session_token' ],
149
- endpoint_url,
150
- content_encoding
169
+ endpoint_url,
170
+ read_timeout
151
171
]
152
172
)[0 ][' num_rows' ]
153
173
$$;
@@ -162,6 +182,9 @@ CREATE OR REPLACE FUNCTION aws_s3.query_export_to_s3(
162
182
session_token text default null ,
163
183
options text default null ,
164
184
endpoint_url text default null ,
185
+ read_timeout integer default 60 ,
186
+ override boolean default false,
187
+ tempfile_dir text default ' /var/lib/postgresql/data/' ,
165
188
OUT rows_uploaded bigint ,
166
189
OUT files_uploaded bigint ,
167
190
OUT bytes_uploaded bigint
@@ -180,8 +203,19 @@ AS $$
180
203
module_cache[module_name] = _module
181
204
return _module
182
205
206
def file_exists(bucket, file_path, s3_client):
    """Report whether ``file_path`` already exists as a key in ``bucket``.

    The check is a single ``head_object`` call on ``s3_client``.

    Args:
        bucket: Name of the S3 bucket to probe.
        file_path: Object key to look for.
        s3_client: Client object exposing ``head_object(Bucket=..., Key=...)``.

    Returns:
        True when ``head_object`` succeeds, False when the service answers
        that the key does not exist.

    Raises:
        Any error other than "not found" (access denied, network failure,
        missing credentials, ...) is re-raised unchanged. Reporting such a
        failure as "missing" would let the caller pick a key that may
        already exist and silently overwrite it.
    """
    try:
        s3_client.head_object(Bucket=bucket, Key=file_path)
    except Exception as exc:
        # Service errors carry the parsed error payload on ``response``;
        # inspect it by attribute so no extra import is needed here.
        error_info = getattr(exc, 'response', None)
        error_code = ''
        if isinstance(error_info, dict):
            error_code = str((error_info.get('Error') or {}).get('Code', ''))
        # A HEAD request has no body, so a missing key is reported with the
        # bare HTTP status as its code; the named variants are accepted too.
        if error_code in ('404', 'NoSuchKey', 'NotFound'):
            return False
        raise
    return True
212
+
213
def get_unique_file_path(base_name, counter, extension):
    """Build a candidate object key of the form ``<base_name>_part<counter><extension>``.

    Args:
        base_name: Key text that precedes the ``_part`` marker.
        counter: Sequence number inserted after ``_part``.
        extension: Text appended after the counter; may be empty.

    Returns:
        The three pieces joined into one string, with no added whitespace.
    """
    return f"{base_name}_part{counter}{extension}"
215
+
183
216
boto3 = cache_import(' boto3' )
184
217
tempfile = cache_import(' tempfile' )
218
+ re = cache_import(" re" )
185
219
186
220
plan = plpy .prepare (" select name, current_setting('aws_s3.' || name, true) as value from (select unnest(array['access_key_id', 'secret_access_key', 'session_token', 'endpoint_url']) as name) a" );
187
221
default_aws_settings = {
@@ -199,10 +233,22 @@ AS $$
199
233
s3 = boto3 .client (
200
234
' s3' ,
201
235
region_name= region,
236
+ config= boto3 .session .Config(read_timeout= read_timeout),
202
237
** aws_settings
203
238
)
204
239
205
- with tempfile .NamedTemporaryFile () as fd:
240
+ upload_file_path = file_path
241
+ if not override:
242
+ # generate unique file path
243
+ file_path_parts = re .match (r' ^(.*?)(\. [^.]*$|$)' , upload_file_path)
244
+ base_name = file_path_parts .group (1 )
245
+ extension = file_path_parts .group (2 )
246
+ counter = 0
247
+ while file_exists(bucket, get_unique_file_path(base_name, counter, extension), s3):
248
+ counter + = 1
249
+ upload_file_path = get_unique_file_path(base_name, counter, extension)
250
+
251
+ with tempfile .NamedTemporaryFile (dir= tempfile_dir) as fd:
206
252
plan = plpy .prepare (
207
253
" COPY ({query}) TO '{filename}' {options}" .format(
208
254
query= query,
@@ -221,7 +267,7 @@ AS $$
221
267
num_lines + = buffer .count (b' \n ' )
222
268
size + = len(buffer)
223
269
fd .seek (0 )
224
- s3 .upload_fileobj (fd, bucket, file_path )
270
+ s3 .upload_fileobj (fd, bucket, upload_file_path )
225
271
if ' HEADER TRUE' in options .upper ():
226
272
num_lines - = 1
227
273
yield (num_lines, 1 , size)
@@ -233,15 +279,18 @@ CREATE OR REPLACE FUNCTION aws_s3.query_export_to_s3(
233
279
credentials aws_commons ._aws_credentials_1 default null ,
234
280
options text default null ,
235
281
endpoint_url text default null ,
282
+ read_timeout integer default 60 ,
283
+ override boolean default false,
284
+ tempfile_dir text default ' /var/lib/postgresql/data/' ,
236
285
OUT rows_uploaded bigint ,
237
286
OUT files_uploaded bigint ,
238
287
OUT bytes_uploaded bigint
239
288
) RETURNS SETOF RECORD
240
289
LANGUAGE plpython3u
241
290
AS $$
242
291
plan = plpy .prepare (
243
- ' SELECT * FROM aws_s3.query_export_to_s3($1, $2, $3, $4, $5, $6, $7, $8, $9)' ,
244
- [' TEXT' , ' TEXT' , ' TEXT' , ' TEXT' , ' TEXT' , ' TEXT' , ' TEXT' , ' TEXT' , ' TEXT' ]
292
+ ' SELECT * FROM aws_s3.query_export_to_s3($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12 )' ,
293
+ [' TEXT' , ' TEXT' , ' TEXT' , ' TEXT' , ' TEXT' , ' TEXT' , ' TEXT' , ' TEXT' , ' TEXT' , ' INTEGER ' , ' BOOLEAN ' , ' TEXT ' ]
245
294
)
246
295
return plan .execute (
247
296
[
@@ -253,7 +302,8 @@ AS $$
253
302
credentials .get (' secret_key' ) if credentials else None,
254
303
credentials .get (' session_token' ) if credentials else None,
255
304
options,
256
- endpoint_url
305
+ endpoint_url,
306
+ read_timeout
257
307
]
258
308
)
259
309
$$;
0 commit comments