diff --git a/README.md b/README.md index 06c0a0f..0e69908 100644 --- a/README.md +++ b/README.md @@ -15,7 +15,8 @@ Feel free to use the abstraction as provided or else modify them / extend them a ## Requirements -The package currently only supports the [psycogp3](https://www.psycopg.org/psycopg3/) driver. +- [psycopg3](https://www.psycopg.org/psycopg3/): The PostgreSQL driver. +- [psycopg_pool](https://www.psycopg.org/psycopg3/docs/advanced/pool.html): For connection pooling support. ## Installation @@ -25,24 +26,23 @@ pip install -U langchain-postgres ## Change Log -0.0.6: -- Remove langgraph as a dependency as it was causing dependency conflicts. -- Base interface for checkpointer changed in langgraph, so existing implementation would've broken regardless. +**0.0.7:** + +- Added support for asynchronous connection pooling in `PostgresChatMessageHistory`. +- Adjusted parameter order in `PostgresChatMessageHistory` to make `session_id` the first parameter. ## Usage ### ChatMessageHistory -The chat message history abstraction helps to persist chat message history -in a postgres table. +The chat message history abstraction helps to persist chat message history in a Postgres table. -PostgresChatMessageHistory is parameterized using a `table_name` and a `session_id`. +`PostgresChatMessageHistory` is parameterized using a `session_id` and an optional `table_name` (default is `"chat_history"`). -The `table_name` is the name of the table in the database where -the chat messages will be stored. +- **`session_id`:** A unique identifier for the chat session. It can be assigned using `uuid.uuid4()`. +- **`table_name`:** The name of the table in the database where the chat messages will be stored. -The `session_id` is a unique identifier for the chat session. It can be assigned -by the caller using `uuid.uuid4()`. +#### **Synchronous Usage** ```python import uuid @@ -52,8 +52,7 @@ from langchain_postgres import PostgresChatMessageHistory import psycopg # Establish a synchronous connection to the database -# (or use psycopg.AsyncConnection for async) -conn_info = ... # Fill in with your connection info +conn_info = "postgresql://user:password@host:port/dbname" # Replace with your connection info sync_connection = psycopg.connect(conn_info) # Create the table schema (only needs to be done once) @@ -64,8 +63,8 @@ session_id = str(uuid.uuid4()) # Initialize the chat history manager chat_history = PostgresChatMessageHistory( - table_name, - session_id, + session_id=session_id, + table_name=table_name, sync_connection=sync_connection ) @@ -79,6 +78,57 @@ chat_history.add_messages([ print(chat_history.messages) ``` +#### **Asynchronous Usage with Connection Pooling** + +```python +import uuid +import asyncio + +from langchain_core.messages import SystemMessage, AIMessage, HumanMessage +from langchain_postgres import PostgresChatMessageHistory +from psycopg_pool import AsyncConnectionPool + +# Asynchronous main function +async def main(): + # Database connection string + conn_info = "postgresql://user:password@host:port/dbname" # Replace with your connection info + + # Initialize the connection pool + pool = AsyncConnectionPool(conninfo=conn_info) + + try: + # Create the table schema (only needs to be done once) + async with pool.connection() as async_connection: + table_name = "chat_history" + await PostgresChatMessageHistory.adrop_table(async_connection, table_name) + await PostgresChatMessageHistory.acreate_tables(async_connection, table_name) + + session_id = str(uuid.uuid4()) + + # Initialize the chat history manager with the connection pool + chat_history = PostgresChatMessageHistory( + session_id=session_id, + table_name=table_name, + conn_pool=pool + ) + + # Add messages to the chat history asynchronously + await chat_history.aadd_messages([ + SystemMessage(content="System message"), + AIMessage(content="AI response"), + HumanMessage(content="Human message"), + ]) + + # Retrieve messages from the chat history + messages = await chat_history.aget_messages() + print(messages) + finally: + # Close the connection pool + await pool.close() + +# Run the async main function +asyncio.run(main()) +``` ### Vectorstore diff --git a/langchain_postgres/chat_message_histories.py b/langchain_postgres/chat_message_histories.py index 85f2aef..605ff76 100644 --- a/langchain_postgres/chat_message_histories.py +++ b/langchain_postgres/chat_message_histories.py @@ -11,6 +11,9 @@ from typing import List, Optional, Sequence import psycopg +from psycopg_pool import AsyncConnectionPool +from typing import Optional, Union, AsyncGenerator +from contextlib import asynccontextmanager from langchain_core.chat_history import BaseChatMessageHistory from langchain_core.messages import BaseMessage, message_to_dict, messages_from_dict from psycopg import sql @@ -77,12 +80,12 @@ def _insert_message_query(table_name: str) -> sql.Composed: class PostgresChatMessageHistory(BaseChatMessageHistory): def __init__( self, - table_name: str, session_id: str, - /, + table_name: str = "chat_history", *, sync_connection: Optional[psycopg.Connection] = None, async_connection: Optional[psycopg.AsyncConnection] = None, + conn_pool: Optional[AsyncConnectionPool] = None, ) -> None: """Client for persisting chat message history in a Postgres database, @@ -132,6 +135,8 @@ def __init__( table_name: The name of the database table to use sync_connection: An existing psycopg connection instance async_connection: An existing psycopg async connection instance + conn_pool: AsyncConnectionPool instance for managing async connections. + Usage: - Use the create_tables or acreate_tables method to set up the table @@ -181,11 +186,14 @@ def __init__( print(chat_history.messages) """ - if not sync_connection and not async_connection: - raise ValueError("Must provide sync_connection or async_connection") + if not sync_connection and not async_connection and not conn_pool: + raise ValueError( + "Must provide sync_connection, async_connection, or conn_pool." + ) self._connection = sync_connection self._aconnection = async_connection + self._conn_pool = conn_pool # Validate that session id is a UUID try: @@ -290,23 +298,33 @@ def add_messages(self, messages: Sequence[BaseMessage]) -> None: self._connection.commit() async def aadd_messages(self, messages: Sequence[BaseMessage]) -> None: - """Add messages to the chat message history.""" - if self._aconnection is None: + """Add messages to the chat message history asynchronously.""" + if self._conn_pool is not None: + values = [ + (self._session_id, json.dumps(message_to_dict(message))) + for message in messages + ] + async with self._conn_pool.connection() as async_connection: + query = self._insert_message_query(self._table_name) + async with async_connection.cursor() as cursor: + await cursor.executemany(query, values) + await async_connection.commit() + elif self._aconnection is not None: + # Existing code using self._aconnection + values = [ + (self._session_id, json.dumps(message_to_dict(message))) + for message in messages + ] + query = self._insert_message_query(self._table_name) + async with self._aconnection.cursor() as cursor: + await cursor.executemany(query, values) + await self._aconnection.commit() + else: raise ValueError( "Please initialize the PostgresChatMessageHistory " - "with an async connection or use the sync add_messages method instead." + "with an async connection or connection pool." ) - values = [ - (self._session_id, json.dumps(message_to_dict(message))) - for message in messages - ] - - query = _insert_message_query(self._table_name) - async with self._aconnection.cursor() as cursor: - await cursor.executemany(query, values) - await self._aconnection.commit() - def get_messages(self) -> List[BaseMessage]: """Retrieve messages from the chat message history.""" if self._connection is None: @@ -325,21 +343,29 @@ def get_messages(self) -> List[BaseMessage]: return messages async def aget_messages(self) -> List[BaseMessage]: - """Retrieve messages from the chat message history.""" - if self._aconnection is None: + """Retrieve messages from the chat message history asynchronously.""" + if self._conn_pool is not None: + async with self._conn_pool.connection() as async_connection: + query = self._get_messages_query(self._table_name) + async with async_connection.cursor() as cursor: + await cursor.execute(query, {"session_id": self._session_id}) + items = [record[0] for record in await cursor.fetchall()] + messages = messages_from_dict(items) + return messages + elif self._aconnection is not None: + # Existing code using self._aconnection + query = self._get_messages_query(self._table_name) + async with self._aconnection.cursor() as cursor: + await cursor.execute(query, {"session_id": self._session_id}) + items = [record[0] for record in await cursor.fetchall()] + messages = messages_from_dict(items) + return messages + else: raise ValueError( "Please initialize the PostgresChatMessageHistory " - "with an async connection or use the sync get_messages method instead." + "with an async connection or connection pool." ) - query = _get_messages_query(self._table_name) - async with self._aconnection.cursor() as cursor: - await cursor.execute(query, {"session_id": self._session_id}) - items = [record[0] for record in await cursor.fetchall()] - - messages = messages_from_dict(items) - return messages - @property # type: ignore[override] def messages(self) -> List[BaseMessage]: """The abstraction required a property.""" diff --git a/tests/unit_tests/test_chat_histories.py b/tests/unit_tests/test_chat_histories.py index 121ee2e..44459e1 100644 --- a/tests/unit_tests/test_chat_histories.py +++ b/tests/unit_tests/test_chat_histories.py @@ -3,7 +3,8 @@ from langchain_core.messages import AIMessage, HumanMessage, SystemMessage from langchain_postgres.chat_message_histories import PostgresChatMessageHistory -from tests.utils import asyncpg_client, syncpg_client +from psycopg_pool import AsyncConnectionPool +from tests.utils import asyncpg_client, syncpg_client, DSN def test_sync_chat_history() -> None: @@ -121,3 +122,75 @@ async def test_async_chat_history() -> None: # clear await chat_history.aclear() assert await chat_history.aget_messages() == [] + + +async def test_async_chat_history_with_pool() -> None: + """Test the async chat history using a connection pool.""" + # Initialize the connection pool + pool = AsyncConnectionPool(conninfo=DSN) + try: + table_name = "chat_history" + session_id = str(uuid.uuid4()) + + # Create tables using a connection from the pool + async with pool.connection() as async_connection: + await PostgresChatMessageHistory.adrop_table(async_connection, table_name) + await PostgresChatMessageHistory.acreate_tables(async_connection, table_name) + + # Create PostgresChatMessageHistory with conn_pool + chat_history = PostgresChatMessageHistory( + session_id=session_id, + table_name=table_name, + conn_pool=pool, + ) + + # Ensure the chat history is empty + messages = await chat_history.aget_messages() + assert messages == [] + + # Add messages to the chat history + await chat_history.aadd_messages( + [ + SystemMessage(content="System message"), + AIMessage(content="AI response"), + HumanMessage(content="Human message"), + ] + ) + + # Retrieve messages from the chat history + messages = await chat_history.aget_messages() + assert len(messages) == 3 + assert messages == [ + SystemMessage(content="System message"), + AIMessage(content="AI response"), + HumanMessage(content="Human message"), + ] + + # Add more messages + await chat_history.aadd_messages( + [ + SystemMessage(content="Another system message"), + AIMessage(content="Another AI response"), + HumanMessage(content="Another human message"), + ] + ) + + # Verify all messages are retrieved + messages = await chat_history.aget_messages() + assert len(messages) == 6 + assert messages == [ + SystemMessage(content="System message"), + AIMessage(content="AI response"), + HumanMessage(content="Human message"), + SystemMessage(content="Another system message"), + AIMessage(content="Another AI response"), + HumanMessage(content="Another human message"), + ] + + # Clear the chat history + await chat_history.aclear() + messages = await chat_history.aget_messages() + assert messages == [] + finally: + # Close the connection pool + await pool.close()