Skip to content

Add Asynchronous Connection Pooling Support to PostgresChatMessageHistory #130

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
Open
80 changes: 65 additions & 15 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@ Feel free to use the abstraction as provided or else modify them / extend them a

## Requirements

The package currently only supports the [psycogp3](https://www.psycopg.org/psycopg3/) driver.
- [psycopg3](https://www.psycopg.org/psycopg3/): The PostgreSQL driver.
- [psycopg_pool](https://www.psycopg.org/psycopg3/docs/advanced/pool.html): For connection pooling support.

## Installation

Expand All @@ -25,24 +26,23 @@ pip install -U langchain-postgres

## Change Log

0.0.6:
- Remove langgraph as a dependency as it was causing dependency conflicts.
- Base interface for checkpointer changed in langgraph, so existing implementation would've broken regardless.
**0.0.7:**

- Added support for asynchronous connection pooling in `PostgresChatMessageHistory`.
- Adjusted parameter order in `PostgresChatMessageHistory` to make `session_id` the first parameter.

## Usage

### ChatMessageHistory

The chat message history abstraction helps to persist chat message history
in a postgres table.
The chat message history abstraction helps to persist chat message history in a Postgres table.

PostgresChatMessageHistory is parameterized using a `table_name` and a `session_id`.
`PostgresChatMessageHistory` is parameterized using a `session_id` and an optional `table_name` (default is `"chat_history"`).

The `table_name` is the name of the table in the database where
the chat messages will be stored.
- **`session_id`:** A unique identifier for the chat session. It can be assigned using `uuid.uuid4()`.
- **`table_name`:** The name of the table in the database where the chat messages will be stored.

The `session_id` is a unique identifier for the chat session. It can be assigned
by the caller using `uuid.uuid4()`.
#### **Synchronous Usage**

```python
import uuid
Expand All @@ -52,8 +52,7 @@ from langchain_postgres import PostgresChatMessageHistory
import psycopg

# Establish a synchronous connection to the database
# (or use psycopg.AsyncConnection for async)
conn_info = ... # Fill in with your connection info
conn_info = "postgresql://user:password@host:port/dbname" # Replace with your connection info
sync_connection = psycopg.connect(conn_info)

# Create the table schema (only needs to be done once)
Expand All @@ -64,8 +63,8 @@ session_id = str(uuid.uuid4())

# Initialize the chat history manager
chat_history = PostgresChatMessageHistory(
table_name,
session_id,
session_id=session_id,
table_name=table_name,
sync_connection=sync_connection
)

Expand All @@ -79,6 +78,57 @@ chat_history.add_messages([
print(chat_history.messages)
```

#### **Asynchronous Usage with Connection Pooling**

```python
import uuid
import asyncio

from langchain_core.messages import SystemMessage, AIMessage, HumanMessage
from langchain_postgres import PostgresChatMessageHistory
from psycopg_pool import AsyncConnectionPool

# Asynchronous main function
async def main():
# Database connection string
conn_info = "postgresql://user:password@host:port/dbname" # Replace with your connection info

# Initialize the connection pool
pool = AsyncConnectionPool(conninfo=conn_info)

try:
# Create the table schema (only needs to be done once)
async with pool.connection() as async_connection:
table_name = "chat_history"
await PostgresChatMessageHistory.adrop_table(async_connection, table_name)
await PostgresChatMessageHistory.acreate_tables(async_connection, table_name)

session_id = str(uuid.uuid4())

# Initialize the chat history manager with the connection pool
chat_history = PostgresChatMessageHistory(
session_id=session_id,
table_name=table_name,
conn_pool=pool
)

# Add messages to the chat history asynchronously
await chat_history.aadd_messages([
SystemMessage(content="System message"),
AIMessage(content="AI response"),
HumanMessage(content="Human message"),
])

# Retrieve messages from the chat history
messages = await chat_history.aget_messages()
print(messages)
finally:
# Close the connection pool
await pool.close()

# Run the async main function
asyncio.run(main())
```

### Vectorstore

Expand Down
82 changes: 54 additions & 28 deletions langchain_postgres/chat_message_histories.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@
from typing import List, Optional, Sequence

import psycopg
from psycopg_pool import AsyncConnectionPool
from typing import Optional, Union, AsyncGenerator
from contextlib import asynccontextmanager
from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.messages import BaseMessage, message_to_dict, messages_from_dict
from psycopg import sql
Expand Down Expand Up @@ -77,12 +80,12 @@ def _insert_message_query(table_name: str) -> sql.Composed:
class PostgresChatMessageHistory(BaseChatMessageHistory):
def __init__(
self,
table_name: str,
session_id: str,
/,
table_name: str = "chat_history",
*,
sync_connection: Optional[psycopg.Connection] = None,
async_connection: Optional[psycopg.AsyncConnection] = None,
conn_pool: Optional[AsyncConnectionPool] = None,
) -> None:
"""Client for persisting chat message history in a Postgres database,

Expand Down Expand Up @@ -132,6 +135,8 @@ def __init__(
table_name: The name of the database table to use
sync_connection: An existing psycopg connection instance
async_connection: An existing psycopg async connection instance
conn_pool: AsyncConnectionPool instance for managing async connections.


Usage:
- Use the create_tables or acreate_tables method to set up the table
Expand Down Expand Up @@ -181,11 +186,14 @@ def __init__(

print(chat_history.messages)
"""
if not sync_connection and not async_connection:
raise ValueError("Must provide sync_connection or async_connection")
if not sync_connection and not async_connection and not conn_pool:
raise ValueError(
"Must provide sync_connection, async_connection, or conn_pool."
)

self._connection = sync_connection
self._aconnection = async_connection
self._conn_pool = conn_pool

# Validate that session id is a UUID
try:
Expand Down Expand Up @@ -290,23 +298,33 @@ def add_messages(self, messages: Sequence[BaseMessage]) -> None:
self._connection.commit()

async def aadd_messages(self, messages: Sequence[BaseMessage]) -> None:
"""Add messages to the chat message history."""
if self._aconnection is None:
"""Add messages to the chat message history asynchronously."""
if self._conn_pool is not None:
values = [
(self._session_id, json.dumps(message_to_dict(message)))
for message in messages
]
async with self._conn_pool.connection() as async_connection:
query = self._insert_message_query(self._table_name)
async with async_connection.cursor() as cursor:
await cursor.executemany(query, values)
await async_connection.commit()
elif self._aconnection is not None:
# Existing code using self._aconnection
values = [
(self._session_id, json.dumps(message_to_dict(message)))
for message in messages
]
query = self._insert_message_query(self._table_name)
async with self._aconnection.cursor() as cursor:
await cursor.executemany(query, values)
await self._aconnection.commit()
else:
raise ValueError(
"Please initialize the PostgresChatMessageHistory "
"with an async connection or use the sync add_messages method instead."
"with an async connection or connection pool."
)

values = [
(self._session_id, json.dumps(message_to_dict(message)))
for message in messages
]

query = _insert_message_query(self._table_name)
async with self._aconnection.cursor() as cursor:
await cursor.executemany(query, values)
await self._aconnection.commit()

def get_messages(self) -> List[BaseMessage]:
"""Retrieve messages from the chat message history."""
if self._connection is None:
Expand All @@ -325,21 +343,29 @@ def get_messages(self) -> List[BaseMessage]:
return messages

async def aget_messages(self) -> List[BaseMessage]:
"""Retrieve messages from the chat message history."""
if self._aconnection is None:
"""Retrieve messages from the chat message history asynchronously."""
if self._conn_pool is not None:
async with self._conn_pool.connection() as async_connection:
query = self._get_messages_query(self._table_name)
async with async_connection.cursor() as cursor:
await cursor.execute(query, {"session_id": self._session_id})
items = [record[0] for record in await cursor.fetchall()]
messages = messages_from_dict(items)
return messages
elif self._aconnection is not None:
# Existing code using self._aconnection
query = self._get_messages_query(self._table_name)
async with self._aconnection.cursor() as cursor:
await cursor.execute(query, {"session_id": self._session_id})
items = [record[0] for record in await cursor.fetchall()]
messages = messages_from_dict(items)
return messages
else:
raise ValueError(
"Please initialize the PostgresChatMessageHistory "
"with an async connection or use the sync get_messages method instead."
"with an async connection or connection pool."
)

query = _get_messages_query(self._table_name)
async with self._aconnection.cursor() as cursor:
await cursor.execute(query, {"session_id": self._session_id})
items = [record[0] for record in await cursor.fetchall()]

messages = messages_from_dict(items)
return messages

@property # type: ignore[override]
def messages(self) -> List[BaseMessage]:
"""The abstraction required a property."""
Expand Down
75 changes: 74 additions & 1 deletion tests/unit_tests/test_chat_histories.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage

from langchain_postgres.chat_message_histories import PostgresChatMessageHistory
from tests.utils import asyncpg_client, syncpg_client
from psycopg_pool import AsyncConnectionPool
from tests.utils import asyncpg_client, syncpg_client, DSN


def test_sync_chat_history() -> None:
Expand Down Expand Up @@ -121,3 +122,75 @@ async def test_async_chat_history() -> None:
# clear
await chat_history.aclear()
assert await chat_history.aget_messages() == []


async def test_async_chat_history_with_pool() -> None:
"""Test the async chat history using a connection pool."""
# Initialize the connection pool
pool = AsyncConnectionPool(conninfo=DSN)
try:
table_name = "chat_history"
session_id = str(uuid.uuid4())

# Create tables using a connection from the pool
async with pool.connection() as async_connection:
await PostgresChatMessageHistory.adrop_table(async_connection, table_name)
await PostgresChatMessageHistory.acreate_tables(async_connection, table_name)

# Create PostgresChatMessageHistory with conn_pool
chat_history = PostgresChatMessageHistory(
session_id=session_id,
table_name=table_name,
conn_pool=pool,
)

# Ensure the chat history is empty
messages = await chat_history.aget_messages()
assert messages == []

# Add messages to the chat history
await chat_history.aadd_messages(
[
SystemMessage(content="System message"),
AIMessage(content="AI response"),
HumanMessage(content="Human message"),
]
)

# Retrieve messages from the chat history
messages = await chat_history.aget_messages()
assert len(messages) == 3
assert messages == [
SystemMessage(content="System message"),
AIMessage(content="AI response"),
HumanMessage(content="Human message"),
]

# Add more messages
await chat_history.aadd_messages(
[
SystemMessage(content="Another system message"),
AIMessage(content="Another AI response"),
HumanMessage(content="Another human message"),
]
)

# Verify all messages are retrieved
messages = await chat_history.aget_messages()
assert len(messages) == 6
assert messages == [
SystemMessage(content="System message"),
AIMessage(content="AI response"),
HumanMessage(content="Human message"),
SystemMessage(content="Another system message"),
AIMessage(content="Another AI response"),
HumanMessage(content="Another human message"),
]

# Clear the chat history
await chat_history.aclear()
messages = await chat_history.aget_messages()
assert messages == []
finally:
# Close the connection pool
await pool.close()