Skip to content

Commit 3c64e0a

Browse files
committed
add copy_both_simple method
Signed-off-by: Petros Angelatos <petrosagg@gmail.com>
1 parent 148b66c commit 3c64e0a

File tree

4 files changed

+283
-2
lines changed

4 files changed

+283
-2
lines changed

tokio-postgres/src/client.rs

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
use crate::codec::BackendMessages;
22
use crate::config::{Host, SslMode};
33
use crate::connection::{Request, RequestMessages};
4+
use crate::copy_both::CopyBothDuplex;
45
use crate::copy_out::CopyOutStream;
56
use crate::query::RowStream;
67
use crate::simple_query::SimpleQueryStream;
@@ -11,8 +12,9 @@ use crate::types::{Oid, ToSql, Type};
1112
#[cfg(feature = "runtime")]
1213
use crate::Socket;
1314
use crate::{
14-
copy_in, copy_out, prepare, query, simple_query, slice_iter, CancelToken, CopyInSink, Error,
15-
Row, SimpleQueryMessage, Statement, ToStatement, Transaction, TransactionBuilder,
15+
copy_both, copy_in, copy_out, prepare, query, simple_query, slice_iter, CancelToken,
16+
CopyInSink, Error, Row, SimpleQueryMessage, Statement, ToStatement, Transaction,
17+
TransactionBuilder,
1618
};
1719
use bytes::{Buf, BytesMut};
1820
use fallible_iterator::FallibleIterator;
@@ -433,6 +435,15 @@ impl Client {
433435
copy_out::copy_out(self.inner(), statement).await
434436
}
435437

438+
/// Executes a CopyBoth query, returning a combined Stream+Sink type to read and write copy
439+
/// data.
440+
pub async fn copy_both_simple<T>(&self, query: &str) -> Result<CopyBothDuplex<T>, Error>
441+
where
442+
T: Buf + 'static + Send,
443+
{
444+
copy_both::copy_both_simple(self.inner(), query).await
445+
}
446+
436447
/// Executes a sequence of SQL statements using the simple query protocol, returning the resulting rows.
437448
///
438449
/// Statements should be separated by semicolons. If an error occurs, execution of the sequence will stop at that

tokio-postgres/src/connection.rs

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
use crate::codec::{BackendMessage, BackendMessages, FrontendMessage, PostgresCodec};
2+
use crate::copy_both::CopyBothReceiver;
23
use crate::copy_in::CopyInReceiver;
34
use crate::error::DbError;
45
use crate::maybe_tls_stream::MaybeTlsStream;
@@ -21,6 +22,7 @@ use tokio_util::codec::Framed;
2122
pub enum RequestMessages {
2223
Single(FrontendMessage),
2324
CopyIn(CopyInReceiver),
25+
CopyBoth(CopyBothReceiver),
2426
}
2527

2628
pub struct Request {
@@ -259,6 +261,24 @@ where
259261
.map_err(Error::io)?;
260262
self.pending_request = Some(RequestMessages::CopyIn(receiver));
261263
}
264+
RequestMessages::CopyBoth(mut receiver) => {
265+
let message = match receiver.poll_next_unpin(cx) {
266+
Poll::Ready(Some(message)) => message,
267+
Poll::Ready(None) => {
268+
trace!("poll_write: finished copy_both request");
269+
continue;
270+
}
271+
Poll::Pending => {
272+
trace!("poll_write: waiting on copy_both stream");
273+
self.pending_request = Some(RequestMessages::CopyBoth(receiver));
274+
return Ok(true);
275+
}
276+
};
277+
Pin::new(&mut self.stream)
278+
.start_send(message)
279+
.map_err(Error::io)?;
280+
self.pending_request = Some(RequestMessages::CopyBoth(receiver));
281+
}
262282
}
263283
}
264284
}

tokio-postgres/src/copy_both.rs

Lines changed: 249 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,249 @@
1+
use crate::client::{InnerClient, Responses};
2+
use crate::codec::FrontendMessage;
3+
use crate::connection::RequestMessages;
4+
use crate::{simple_query, Error};
5+
use bytes::{Buf, BufMut, Bytes, BytesMut};
6+
use futures::channel::mpsc;
7+
use futures::future;
8+
use futures::{ready, Sink, SinkExt, Stream, StreamExt};
9+
use log::debug;
10+
use pin_project_lite::pin_project;
11+
use postgres_protocol::message::backend::Message;
12+
use postgres_protocol::message::frontend;
13+
use postgres_protocol::message::frontend::CopyData;
14+
use std::marker::{PhantomData, PhantomPinned};
15+
use std::pin::Pin;
16+
use std::task::{Context, Poll};
17+
18+
pub(crate) enum CopyBothMessage {
19+
Message(FrontendMessage),
20+
Done,
21+
}
22+
23+
pub struct CopyBothReceiver {
24+
receiver: mpsc::Receiver<CopyBothMessage>,
25+
done: bool,
26+
}
27+
28+
impl CopyBothReceiver {
29+
pub(crate) fn new(receiver: mpsc::Receiver<CopyBothMessage>) -> CopyBothReceiver {
30+
CopyBothReceiver {
31+
receiver,
32+
done: false,
33+
}
34+
}
35+
}
36+
37+
impl Stream for CopyBothReceiver {
38+
type Item = FrontendMessage;
39+
40+
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<FrontendMessage>> {
41+
if self.done {
42+
return Poll::Ready(None);
43+
}
44+
45+
match ready!(self.receiver.poll_next_unpin(cx)) {
46+
Some(CopyBothMessage::Message(message)) => Poll::Ready(Some(message)),
47+
Some(CopyBothMessage::Done) => {
48+
self.done = true;
49+
let mut buf = BytesMut::new();
50+
frontend::copy_done(&mut buf);
51+
frontend::sync(&mut buf);
52+
Poll::Ready(Some(FrontendMessage::Raw(buf.freeze())))
53+
}
54+
None => {
55+
self.done = true;
56+
let mut buf = BytesMut::new();
57+
frontend::copy_fail("", &mut buf).unwrap();
58+
frontend::sync(&mut buf);
59+
Poll::Ready(Some(FrontendMessage::Raw(buf.freeze())))
60+
}
61+
}
62+
}
63+
}
64+
65+
enum SinkState {
66+
Active,
67+
Closing,
68+
Reading,
69+
}
70+
71+
pin_project! {
72+
/// A sink for `COPY ... FROM STDIN` query data.
73+
///
74+
/// The copy *must* be explicitly completed via the `Sink::close` or `finish` methods. If it is
75+
/// not, the copy will be aborted.
76+
pub struct CopyBothDuplex<T> {
77+
#[pin]
78+
sender: mpsc::Sender<CopyBothMessage>,
79+
responses: Responses,
80+
buf: BytesMut,
81+
state: SinkState,
82+
#[pin]
83+
_p: PhantomPinned,
84+
_p2: PhantomData<T>,
85+
}
86+
}
87+
88+
impl<T> CopyBothDuplex<T>
89+
where
90+
T: Buf + 'static + Send,
91+
{
92+
pub(crate) fn new(sender: mpsc::Sender<CopyBothMessage>, responses: Responses) -> Self {
93+
Self {
94+
sender,
95+
responses,
96+
buf: BytesMut::new(),
97+
state: SinkState::Active,
98+
_p: PhantomPinned,
99+
_p2: PhantomData,
100+
}
101+
}
102+
103+
/// A poll-based version of `finish`.
104+
pub fn poll_finish(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<u64, Error>> {
105+
loop {
106+
match self.state {
107+
SinkState::Active => {
108+
ready!(self.as_mut().poll_flush(cx))?;
109+
let mut this = self.as_mut().project();
110+
ready!(this.sender.as_mut().poll_ready(cx)).map_err(|_| Error::closed())?;
111+
this.sender
112+
.start_send(CopyBothMessage::Done)
113+
.map_err(|_| Error::closed())?;
114+
*this.state = SinkState::Closing;
115+
}
116+
SinkState::Closing => {
117+
let this = self.as_mut().project();
118+
ready!(this.sender.poll_close(cx)).map_err(|_| Error::closed())?;
119+
*this.state = SinkState::Reading;
120+
}
121+
SinkState::Reading => {
122+
let this = self.as_mut().project();
123+
match ready!(this.responses.poll_next(cx))? {
124+
Message::CommandComplete(body) => {
125+
let rows = body
126+
.tag()
127+
.map_err(Error::parse)?
128+
.rsplit(' ')
129+
.next()
130+
.unwrap()
131+
.parse()
132+
.unwrap_or(0);
133+
return Poll::Ready(Ok(rows));
134+
}
135+
_ => return Poll::Ready(Err(Error::unexpected_message())),
136+
}
137+
}
138+
}
139+
}
140+
}
141+
142+
/// Completes the copy, returning the number of rows inserted.
143+
///
144+
/// The `Sink::close` method is equivalent to `finish`, except that it does not return the
145+
/// number of rows.
146+
pub async fn finish(mut self: Pin<&mut Self>) -> Result<u64, Error> {
147+
future::poll_fn(|cx| self.as_mut().poll_finish(cx)).await
148+
}
149+
}
150+
151+
impl<T> Stream for CopyBothDuplex<T> {
152+
type Item = Result<Bytes, Error>;
153+
154+
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
155+
let this = self.project();
156+
157+
match ready!(this.responses.poll_next(cx)?) {
158+
Message::CopyData(body) => Poll::Ready(Some(Ok(body.into_bytes()))),
159+
Message::CopyDone => Poll::Ready(None),
160+
_ => Poll::Ready(Some(Err(Error::unexpected_message()))),
161+
}
162+
}
163+
}
164+
165+
impl<T> Sink<T> for CopyBothDuplex<T>
166+
where
167+
T: Buf + 'static + Send,
168+
{
169+
type Error = Error;
170+
171+
fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
172+
self.project()
173+
.sender
174+
.poll_ready(cx)
175+
.map_err(|_| Error::closed())
176+
}
177+
178+
fn start_send(self: Pin<&mut Self>, item: T) -> Result<(), Error> {
179+
let this = self.project();
180+
181+
let data: Box<dyn Buf + Send> = if item.remaining() > 4096 {
182+
if this.buf.is_empty() {
183+
Box::new(item)
184+
} else {
185+
Box::new(this.buf.split().freeze().chain(item))
186+
}
187+
} else {
188+
this.buf.put(item);
189+
if this.buf.len() > 4096 {
190+
Box::new(this.buf.split().freeze())
191+
} else {
192+
return Ok(());
193+
}
194+
};
195+
196+
let data = CopyData::new(data).map_err(Error::encode)?;
197+
this.sender
198+
.start_send(CopyBothMessage::Message(FrontendMessage::CopyData(data)))
199+
.map_err(|_| Error::closed())
200+
}
201+
202+
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
203+
let mut this = self.project();
204+
205+
if !this.buf.is_empty() {
206+
ready!(this.sender.as_mut().poll_ready(cx)).map_err(|_| Error::closed())?;
207+
let data: Box<dyn Buf + Send> = Box::new(this.buf.split().freeze());
208+
let data = CopyData::new(data).map_err(Error::encode)?;
209+
this.sender
210+
.as_mut()
211+
.start_send(CopyBothMessage::Message(FrontendMessage::CopyData(data)))
212+
.map_err(|_| Error::closed())?;
213+
}
214+
215+
this.sender.poll_flush(cx).map_err(|_| Error::closed())
216+
}
217+
218+
fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
219+
self.poll_finish(cx).map_ok(|_| ())
220+
}
221+
}
222+
223+
pub async fn copy_both_simple<T>(
224+
client: &InnerClient,
225+
query: &str,
226+
) -> Result<CopyBothDuplex<T>, Error>
227+
where
228+
T: Buf + 'static + Send,
229+
{
230+
debug!("executing copy both query {}", query);
231+
232+
let buf = simple_query::encode(client, query)?;
233+
234+
let (mut sender, receiver) = mpsc::channel(1);
235+
let receiver = CopyBothReceiver::new(receiver);
236+
let mut responses = client.send(RequestMessages::CopyBoth(receiver))?;
237+
238+
sender
239+
.send(CopyBothMessage::Message(FrontendMessage::Raw(buf)))
240+
.await
241+
.map_err(|_| Error::closed())?;
242+
243+
match responses.next().await? {
244+
Message::CopyBothResponse(_) => {}
245+
_ => return Err(Error::unexpected_message()),
246+
}
247+
248+
Ok(CopyBothDuplex::new(sender, responses))
249+
}

tokio-postgres/src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,7 @@ mod connect_raw;
155155
mod connect_socket;
156156
mod connect_tls;
157157
mod connection;
158+
mod copy_both;
158159
mod copy_in;
159160
mod copy_out;
160161
pub mod error;

0 commit comments

Comments
 (0)