Skip to content

Commit a27a406

Browse files
committed
Allow passing precomputed SCRAM keys via Config
According to https://datatracker.ietf.org/doc/html/rfc5802#section-3, SCRAM protocol explicitly allows client to use a `ClientKey` & `ServerKey` pair instead of a password to perform authentication. This is also useful for proxy implementations which would like to leverage `rust-postgres`. This patch adds the ability to do that.
1 parent 2b4beff commit a27a406

File tree

4 files changed

+125
-45
lines changed

4 files changed

+125
-45
lines changed

postgres-protocol/src/authentication/sasl.rs

Lines changed: 75 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -96,14 +96,32 @@ impl ChannelBinding {
9696
}
9797
}
9898

99+
/// A pair of keys for the SCRAM-SHA-256 mechanism.
100+
/// See <https://datatracker.ietf.org/doc/html/rfc5802#section-3> for details.
101+
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
102+
pub struct ScramKeys<const N: usize> {
103+
/// Used by server to authenticate client.
104+
pub client_key: [u8; N],
105+
/// Used by client to verify server's signature.
106+
pub server_key: [u8; N],
107+
}
108+
109+
/// Password or keys which were derived from it.
110+
enum Credentials<const N: usize> {
111+
/// A regular password as a vector of bytes.
112+
Password(Vec<u8>),
113+
/// A precomputed pair of keys.
114+
Keys(Box<ScramKeys<N>>),
115+
}
116+
99117
enum State {
100118
Update {
101119
nonce: String,
102-
password: Vec<u8>,
120+
password: Credentials<32>,
103121
channel_binding: ChannelBinding,
104122
},
105123
Finish {
106-
salted_password: [u8; 32],
124+
server_key: [u8; 32],
107125
auth_message: String,
108126
},
109127
Done,
@@ -129,30 +147,43 @@ pub struct ScramSha256 {
129147
state: State,
130148
}
131149

150+
fn nonce() -> String {
151+
// rand 0.5's ThreadRng is cryptographically secure
152+
let mut rng = rand::thread_rng();
153+
(0..NONCE_LENGTH)
154+
.map(|_| {
155+
let mut v = rng.gen_range(0x21u8..0x7e);
156+
if v == 0x2c {
157+
v = 0x7e
158+
}
159+
v as char
160+
})
161+
.collect()
162+
}
163+
132164
impl ScramSha256 {
133165
/// Constructs a new instance which will use the provided password for authentication.
134166
pub fn new(password: &[u8], channel_binding: ChannelBinding) -> ScramSha256 {
135-
// rand 0.5's ThreadRng is cryptographically secure
136-
let mut rng = rand::thread_rng();
137-
let nonce = (0..NONCE_LENGTH)
138-
.map(|_| {
139-
let mut v = rng.gen_range(0x21u8..0x7e);
140-
if v == 0x2c {
141-
v = 0x7e
142-
}
143-
v as char
144-
})
145-
.collect::<String>();
167+
let password = Credentials::Password(normalize(password));
168+
ScramSha256::new_inner(password, channel_binding, nonce())
169+
}
146170

147-
ScramSha256::new_inner(password, channel_binding, nonce)
171+
/// Constructs a new instance which will use the provided key pair for authentication.
172+
pub fn new_with_keys(keys: ScramKeys<32>, channel_binding: ChannelBinding) -> ScramSha256 {
173+
let password = Credentials::Keys(keys.into());
174+
ScramSha256::new_inner(password, channel_binding, nonce())
148175
}
149176

150-
fn new_inner(password: &[u8], channel_binding: ChannelBinding, nonce: String) -> ScramSha256 {
177+
fn new_inner(
178+
password: Credentials<32>,
179+
channel_binding: ChannelBinding,
180+
nonce: String,
181+
) -> ScramSha256 {
151182
ScramSha256 {
152183
message: format!("{}n=,r={}", channel_binding.gs2_header(), nonce),
153184
state: State::Update {
154185
nonce,
155-
password: normalize(password),
186+
password,
156187
channel_binding,
157188
},
158189
}
@@ -189,20 +220,32 @@ impl ScramSha256 {
189220
return Err(io::Error::new(io::ErrorKind::InvalidInput, "invalid nonce"));
190221
}
191222

192-
let salt = match base64::decode(parsed.salt) {
193-
Ok(salt) => salt,
194-
Err(e) => return Err(io::Error::new(io::ErrorKind::InvalidInput, e)),
195-
};
223+
let (client_key, server_key) = match password {
224+
Credentials::Password(password) => {
225+
let salt = match base64::decode(parsed.salt) {
226+
Ok(salt) => salt,
227+
Err(e) => return Err(io::Error::new(io::ErrorKind::InvalidInput, e)),
228+
};
196229

197-
let salted_password = hi(&password, &salt, parsed.iteration_count);
230+
let salted_password = hi(&password, &salt, parsed.iteration_count);
198231

199-
let mut hmac = Hmac::<Sha256>::new_from_slice(&salted_password)
200-
.expect("HMAC is able to accept all key sizes");
201-
hmac.update(b"Client Key");
202-
let client_key = hmac.finalize().into_bytes();
232+
let make_key = |name| {
233+
let mut hmac = Hmac::<Sha256>::new_from_slice(&salted_password)
234+
.expect("HMAC is able to accept all key sizes");
235+
hmac.update(name);
236+
237+
let mut key = [0u8; 32];
238+
key.copy_from_slice(hmac.finalize().into_bytes().as_slice());
239+
key
240+
};
241+
242+
(make_key(b"Client Key"), make_key(b"Server Key"))
243+
}
244+
Credentials::Keys(keys) => (keys.client_key, keys.server_key),
245+
};
203246

204247
let mut hash = Sha256::default();
205-
hash.update(client_key.as_slice());
248+
hash.update(client_key);
206249
let stored_key = hash.finalize_fixed();
207250

208251
let mut cbind_input = vec![];
@@ -225,10 +268,10 @@ impl ScramSha256 {
225268
*proof ^= signature;
226269
}
227270

228-
write!(&mut self.message, ",p={}", base64::encode(&*client_proof)).unwrap();
271+
write!(&mut self.message, ",p={}", base64::encode(client_proof)).unwrap();
229272

230273
self.state = State::Finish {
231-
salted_password,
274+
server_key,
232275
auth_message,
233276
};
234277
Ok(())
@@ -239,11 +282,11 @@ impl ScramSha256 {
239282
/// This should be called when the backend sends an `AuthenticationSASLFinal` message.
240283
/// Authentication has only succeeded if this method returns `Ok(())`.
241284
pub fn finish(&mut self, message: &[u8]) -> io::Result<()> {
242-
let (salted_password, auth_message) = match mem::replace(&mut self.state, State::Done) {
285+
let (server_key, auth_message) = match mem::replace(&mut self.state, State::Done) {
243286
State::Finish {
244-
salted_password,
287+
server_key,
245288
auth_message,
246-
} => (salted_password, auth_message),
289+
} => (server_key, auth_message),
247290
_ => return Err(io::Error::new(io::ErrorKind::Other, "invalid SCRAM state")),
248291
};
249292

@@ -267,11 +310,6 @@ impl ScramSha256 {
267310
Err(e) => return Err(io::Error::new(io::ErrorKind::InvalidInput, e)),
268311
};
269312

270-
let mut hmac = Hmac::<Sha256>::new_from_slice(&salted_password)
271-
.expect("HMAC is able to accept all key sizes");
272-
hmac.update(b"Server Key");
273-
let server_key = hmac.finalize().into_bytes();
274-
275313
let mut hmac = Hmac::<Sha256>::new_from_slice(&server_key)
276314
.expect("HMAC is able to accept all key sizes");
277315
hmac.update(auth_message.as_bytes());
@@ -458,7 +496,7 @@ mod test {
458496
let server_final = "v=U+ppxD5XUKtradnv8e2MkeupiA8FU87Sg8CXzXHDAzw=";
459497

460498
let mut scram = ScramSha256::new_inner(
461-
password.as_bytes(),
499+
Credentials::Password(normalize(password.as_bytes())),
462500
ChannelBinding::unsupported(),
463501
nonce.to_string(),
464502
);

postgres/src/config.rs

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,9 @@ use std::sync::Arc;
1212
use std::time::Duration;
1313
use tokio::runtime;
1414
#[doc(inline)]
15-
pub use tokio_postgres::config::{ChannelBinding, Host, SslMode, TargetSessionAttrs};
15+
pub use tokio_postgres::config::{
16+
AuthKeys, ChannelBinding, Host, ScramKeys, SslMode, TargetSessionAttrs,
17+
};
1618
use tokio_postgres::error::DbError;
1719
use tokio_postgres::tls::{MakeTlsConnect, TlsConnect};
1820
use tokio_postgres::{Error, Socket};
@@ -149,6 +151,20 @@ impl Config {
149151
self.config.get_password()
150152
}
151153

154+
/// Sets precomputed protocol-specific keys to authenticate with.
155+
/// When set, this option will override `password`.
156+
/// See [`AuthKeys`] for more information.
157+
pub fn auth_keys(&mut self, keys: AuthKeys) -> &mut Config {
158+
self.config.auth_keys(keys);
159+
self
160+
}
161+
162+
/// Gets precomputed protocol-specific keys to authenticate with.
163+
/// if one has been configured with the `auth_keys` method.
164+
pub fn get_auth_keys(&self) -> Option<AuthKeys> {
165+
self.config.get_auth_keys()
166+
}
167+
152168
/// Sets the name of the database to connect to.
153169
///
154170
/// Defaults to the user.

tokio-postgres/src/config.rs

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@ use std::time::Duration;
2323
use std::{error, fmt, iter, mem};
2424
use tokio::io::{AsyncRead, AsyncWrite};
2525

26+
pub use postgres_protocol::authentication::sasl::ScramKeys;
27+
2628
/// Properties required of a session.
2729
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
2830
#[non_exhaustive]
@@ -79,6 +81,13 @@ pub enum Host {
7981
Unix(PathBuf),
8082
}
8183

84+
/// Precomputed keys which may override password during auth.
85+
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
86+
pub enum AuthKeys {
87+
/// A `ClientKey` & `ServerKey` pair for `SCRAM-SHA-256`.
88+
ScramSha256(ScramKeys<32>),
89+
}
90+
8291
/// Connection configuration.
8392
///
8493
/// Configuration can be parsed from libpq-style connection strings. These strings come in two formats:
@@ -163,6 +172,7 @@ pub enum Host {
163172
pub struct Config {
164173
pub(crate) user: Option<String>,
165174
pub(crate) password: Option<Vec<u8>>,
175+
pub(crate) auth_keys: Option<Box<AuthKeys>>,
166176
pub(crate) dbname: Option<String>,
167177
pub(crate) options: Option<String>,
168178
pub(crate) application_name: Option<String>,
@@ -194,6 +204,7 @@ impl Config {
194204
Config {
195205
user: None,
196206
password: None,
207+
auth_keys: None,
197208
dbname: None,
198209
options: None,
199210
application_name: None,
@@ -238,6 +249,20 @@ impl Config {
238249
self.password.as_deref()
239250
}
240251

252+
/// Sets precomputed protocol-specific keys to authenticate with.
253+
/// When set, this option will override `password`.
254+
/// See [`AuthKeys`] for more information.
255+
pub fn auth_keys(&mut self, keys: AuthKeys) -> &mut Config {
256+
self.auth_keys = Some(Box::new(keys));
257+
self
258+
}
259+
260+
/// Gets precomputed protocol-specific keys to authenticate with.
261+
/// if one has been configured with the `auth_keys` method.
262+
pub fn get_auth_keys(&self) -> Option<AuthKeys> {
263+
self.auth_keys.as_deref().copied()
264+
}
265+
241266
/// Sets the name of the database to connect to.
242267
///
243268
/// Defaults to the user.

tokio-postgres/src/connect_raw.rs

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
use crate::codec::{BackendMessage, BackendMessages, FrontendMessage, PostgresCodec};
2-
use crate::config::{self, Config, ReplicationMode};
2+
use crate::config::{self, AuthKeys, Config, ReplicationMode};
33
use crate::connect_tls::connect_tls;
44
use crate::maybe_tls_stream::MaybeTlsStream;
55
use crate::tls::{TlsConnect, TlsStream};
@@ -234,11 +234,6 @@ where
234234
S: AsyncRead + AsyncWrite + Unpin,
235235
T: TlsStream + Unpin,
236236
{
237-
let password = config
238-
.password
239-
.as_ref()
240-
.ok_or_else(|| Error::config("password missing".into()))?;
241-
242237
let mut has_scram = false;
243238
let mut has_scram_plus = false;
244239
let mut mechanisms = body.mechanisms();
@@ -276,7 +271,13 @@ where
276271
can_skip_channel_binding(config)?;
277272
}
278273

279-
let mut scram = ScramSha256::new(password, channel_binding);
274+
let mut scram = if let Some(AuthKeys::ScramSha256(keys)) = config.get_auth_keys() {
275+
ScramSha256::new_with_keys(keys, channel_binding)
276+
} else if let Some(password) = config.get_password() {
277+
ScramSha256::new(password, channel_binding)
278+
} else {
279+
return Err(Error::config("password or auth keys missing".into()));
280+
};
280281

281282
let mut buf = BytesMut::new();
282283
frontend::sasl_initial_response(mechanism, scram.message(), &mut buf).map_err(Error::encode)?;

0 commit comments

Comments
 (0)