diff --git a/langchain_postgres/__init__.py b/langchain_postgres/__init__.py index 15a9230f..6b73876d 100644 --- a/langchain_postgres/__init__.py +++ b/langchain_postgres/__init__.py @@ -2,6 +2,7 @@ from langchain_postgres.chat_message_histories import PostgresChatMessageHistory from langchain_postgres.translator import PGVectorTranslator +from langchain_postgres.v2.chat_message_history import PGChatMessageHistory from langchain_postgres.v2.engine import Column, ColumnDict, PGEngine from langchain_postgres.v2.vectorstores import PGVectorStore from langchain_postgres.vectorstores import PGVector @@ -18,6 +19,7 @@ "ColumnDict", "PGEngine", "PostgresChatMessageHistory", + "PGChatMessageHistory", "PGVector", "PGVectorStore", "PGVectorTranslator", diff --git a/langchain_postgres/v2/async_chat_message_history.py b/langchain_postgres/v2/async_chat_message_history.py new file mode 100644 index 00000000..abb82022 --- /dev/null +++ b/langchain_postgres/v2/async_chat_message_history.py @@ -0,0 +1,182 @@ +from __future__ import annotations + +import json +from typing import Sequence + +from langchain_core.chat_history import BaseChatMessageHistory +from langchain_core.messages import BaseMessage, message_to_dict, messages_from_dict +from sqlalchemy import RowMapping, text +from sqlalchemy.ext.asyncio import AsyncEngine + +from .engine import PGEngine + + +class AsyncPGChatMessageHistory(BaseChatMessageHistory): + """Chat message history stored in a PostgreSQL database.""" + + __create_key = object() + + def __init__( + self, + key: object, + pool: AsyncEngine, + session_id: str, + table_name: str, + store_message: bool, + schema_name: str = "public", + ): + """AsyncPGChatMessageHistory constructor. + + Args: + key (object): Key to prevent direct constructor usage. + pool (AsyncEngine): Database connection pool. + session_id (str): Retrieve the table content with this session ID. + table_name (str): Table name that stores the chat message history. 
+ store_message (bool): Whether to store the whole message or store data & type separately + schema_name (str): The schema name where the table is located (default: "public"). + + Raises: + Exception: If constructor is directly called by the user. + """ + if key != AsyncPGChatMessageHistory.__create_key: + raise Exception( + "Only create class through 'create' or 'create_sync' methods!" + ) + self.pool = pool + self.session_id = session_id + self.table_name = table_name + self.schema_name = schema_name + self.store_message = store_message + + @classmethod + async def create( + cls, + engine: PGEngine, + session_id: str, + table_name: str, + schema_name: str = "public", + ) -> AsyncPGChatMessageHistory: + """Create a new AsyncPGChatMessageHistory instance. + + Args: + engine (PGEngine): PGEngine to use. + session_id (str): Retrieve the table content with this session ID. + table_name (str): Table name that stores the chat message history. + schema_name (str): The schema name where the table is located (default: "public"). + + Raises: + IndexError: If the table provided does not contain required schema. + + Returns: + AsyncPGChatMessageHistory: A newly created instance of AsyncPGChatMessageHistory. + """ + column_names = await engine._aload_table_schema(table_name, schema_name) + + required_columns = ["id", "session_id", "data", "type"] + supported_columns = ["id", "session_id", "message", "created_at"] + + if not (all(x in column_names for x in required_columns)): + if not (all(x in column_names for x in supported_columns)): + raise IndexError( + f"Table '{schema_name}'.'{table_name}' has incorrect schema. 
Got " + f"column names '{column_names}' but required column names " + f"'{required_columns}'.\nPlease create table with following schema:" + f"\nCREATE TABLE {schema_name}.{table_name} (" + "\n id SERIAL PRIMARY KEY," + "\n session_id TEXT NOT NULL," + "\n data JSON NOT NULL," + "\n type TEXT NOT NULL" + "\n);" + ) + + store_message = True if "message" in column_names else False + + return cls( + cls.__create_key, + engine._pool, + session_id, + table_name, + store_message, + schema_name, + ) + + def _insert_query(self, message: BaseMessage) -> tuple[str, dict]: + if self.store_message: + query = f"""INSERT INTO "{self.schema_name}"."{self.table_name}"(session_id, message) VALUES (:session_id, :message)""" + params = { + "message": json.dumps(message_to_dict(message)), + "session_id": self.session_id, + } + else: + query = f"""INSERT INTO "{self.schema_name}"."{self.table_name}"(session_id, data, type) VALUES (:session_id, :data, :type)""" + params = { + "data": json.dumps(message.model_dump()), + "session_id": self.session_id, + "type": message.type, + } + + return query, params + + async def aadd_message(self, message: BaseMessage) -> None: + """Append the message to the record in Postgres""" + query, params = self._insert_query(message) + async with self.pool.connect() as conn: + await conn.execute(text(query), params) + await conn.commit() + + async def aadd_messages(self, messages: Sequence[BaseMessage]) -> None: + """Append a list of messages to the record in Postgres""" + for message in messages: + await self.aadd_message(message) + + async def aclear(self) -> None: + """Clear session memory from Postgres""" + query = f"""DELETE FROM "{self.schema_name}"."{self.table_name}" WHERE session_id = :session_id;""" + async with self.pool.connect() as conn: + await conn.execute(text(query), {"session_id": self.session_id}) + await conn.commit() + + def _select_query(self) -> str: + if self.store_message: + return f"""SELECT message FROM 
"{self.schema_name}"."{self.table_name}" WHERE session_id = :session_id ORDER BY id;""" + else: + return f"""SELECT data, type FROM "{self.schema_name}"."{self.table_name}" WHERE session_id = :session_id ORDER BY id;""" + + def _convert_to_messages(self, rows: Sequence[RowMapping]) -> list[BaseMessage]: + if self.store_message: + items = [row["message"] for row in rows] + messages = messages_from_dict(items) + else: + items = [{"data": row["data"], "type": row["type"]} for row in rows] + messages = messages_from_dict(items) + return messages + + async def _aget_messages(self) -> list[BaseMessage]: + """Retrieve the messages from Postgres.""" + + query = self._select_query() + + async with self.pool.connect() as conn: + result = await conn.execute(text(query), {"session_id": self.session_id}) + result_map = result.mappings() + results = result_map.fetchall() + if not results: + return [] + + messages = self._convert_to_messages(results) + return messages + + def clear(self) -> None: + raise NotImplementedError( + "Sync methods are not implemented for AsyncPGChatMessageHistory. Use PGChatMessageHistory interface instead." + ) + + def add_message(self, message: BaseMessage) -> None: + raise NotImplementedError( + "Sync methods are not implemented for AsyncPGChatMessageHistory. Use PGChatMessageHistory interface instead." + ) + + def add_messages(self, messages: Sequence[BaseMessage]) -> None: + raise NotImplementedError( + "Sync methods are not implemented for AsyncPGChatMessageHistory. Use PGChatMessageHistory interface instead." 
+ ) diff --git a/langchain_postgres/v2/chat_message_history.py b/langchain_postgres/v2/chat_message_history.py new file mode 100644 index 00000000..548935ba --- /dev/null +++ b/langchain_postgres/v2/chat_message_history.py @@ -0,0 +1,129 @@ +from __future__ import annotations + +from typing import Sequence + +from langchain_core.chat_history import BaseChatMessageHistory +from langchain_core.messages import BaseMessage + +from .async_chat_message_history import AsyncPGChatMessageHistory +from .engine import PGEngine + + +class PGChatMessageHistory(BaseChatMessageHistory): + """Chat message history stored in a PostgreSQL database.""" + + __create_key = object() + + def __init__( + self, + key: object, + engine: PGEngine, + history: AsyncPGChatMessageHistory, + ): + """PGChatMessageHistory constructor. + + Args: + key (object): Key to prevent direct constructor usage. + engine (PGEngine): Database connection pool. + history (AsyncPGChatMessageHistory): Async only implementation. + + Raises: + Exception: If constructor is directly called by the user. + """ + if key != PGChatMessageHistory.__create_key: + raise Exception( + "Only create class through 'create' or 'create_sync' methods!" + ) + self._engine = engine + self.__history = history + + @classmethod + async def create( + cls, + engine: PGEngine, + session_id: str, + table_name: str, + schema_name: str = "public", + ) -> PGChatMessageHistory: + """Create a new PGChatMessageHistory instance. + + Args: + engine (PGEngine): PGEngine to use. + session_id (str): Retrieve the table content with this session ID. + table_name (str): Table name that stores the chat message history. + schema_name (str): The schema name where the table is located (default: "public"). + + Raises: + IndexError: If the table provided does not contain required schema. + + Returns: + PGChatMessageHistory: A newly created instance of PGChatMessageHistory. 
+ """ + coro = AsyncPGChatMessageHistory.create( + engine, session_id, table_name, schema_name + ) + history = await engine._run_as_async(coro) + return cls(cls.__create_key, engine, history) + + @classmethod + def create_sync( + cls, + engine: PGEngine, + session_id: str, + table_name: str, + schema_name: str = "public", + ) -> PGChatMessageHistory: + """Create a new PGChatMessageHistory instance. + + Args: + engine (PGEngine): PGEngine to use. + session_id (str): Retrieve the table content with this session ID. + table_name (str): Table name that stores the chat message history. + schema_name: The schema name where the table is located (default: "public"). + + Raises: + IndexError: If the table provided does not contain required schema. + + Returns: + PGChatMessageHistory: A newly created instance of PGChatMessageHistory. + """ + coro = AsyncPGChatMessageHistory.create( + engine, session_id, table_name, schema_name + ) + history = engine._run_as_sync(coro) + return cls(cls.__create_key, engine, history) + + @property + def messages(self) -> list[BaseMessage]: + """Fetches all messages stored in Postgres.""" + return self._engine._run_as_sync(self.__history._aget_messages()) + + @messages.setter + def messages(self, value: list[BaseMessage]) -> None: + """Clear the stored messages and appends a list of messages to the record in Postgres.""" + self.clear() + self.add_messages(value) + + async def aadd_message(self, message: BaseMessage) -> None: + """Append the message to the record in Postgres""" + await self._engine._run_as_async(self.__history.aadd_message(message)) + + def add_message(self, message: BaseMessage) -> None: + """Append the message to the record in Postgres""" + self._engine._run_as_sync(self.__history.aadd_message(message)) + + async def aadd_messages(self, messages: Sequence[BaseMessage]) -> None: + """Append a list of messages to the record in Postgres""" + await self._engine._run_as_async(self.__history.aadd_messages(messages)) + + def 
add_messages(self, messages: Sequence[BaseMessage]) -> None: + """Append a list of messages to the record in Postgres""" + self._engine._run_as_sync(self.__history.aadd_messages(messages)) + + async def aclear(self) -> None: + """Clear session memory from Postgres""" + await self._engine._run_as_async(self.__history.aclear()) + + def clear(self) -> None: + """Clear session memory from Postgres""" + self._engine._run_as_sync(self.__history.aclear()) diff --git a/langchain_postgres/v2/engine.py b/langchain_postgres/v2/engine.py index c2a0d931..5c97de36 100644 --- a/langchain_postgres/v2/engine.py +++ b/langchain_postgres/v2/engine.py @@ -347,6 +347,65 @@ def init_vectorstore_table( ) ) + async def _ainit_chat_history_table( + self, table_name: str, schema_name: str = "public" + ) -> None: + """ + Create a postgres table to save chat history messages. + + Args: + table_name (str): The table name to store chat history. + schema_name (str): The schema name to store the chat history table. + Default: "public". + + Returns: + None + """ + create_table_query = f"""CREATE TABLE IF NOT EXISTS "{schema_name}"."{table_name}"( + id SERIAL PRIMARY KEY, + session_id TEXT NOT NULL, + message JSONB NOT NULL, + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW() + );""" + async with self._pool.connect() as conn: + await conn.execute(text(create_table_query)) + await conn.commit() + + async def ainit_chat_history_table( + self, table_name: str, schema_name: str = "public" + ) -> None: + """Create a postgres table to save chat history messages. + + Args: + table_name (str): The table name to store chat history. + schema_name (str): The schema name to store chat history table. + Default: "public". + + Returns: + None + """ + await self._run_as_async( + self._ainit_chat_history_table( + table_name, + schema_name, + ) + ) + + def init_chat_history_table( + self, table_name: str, schema_name: str = "public" + ) -> None: + """Create a postgres table to store chat history. 
+ + Args: + table_name (str): Table name to store chat history. + schema_name (str): The schema name to store chat history table. + Default: "public". + + Returns: + None + """ + self._run_as_sync(self._ainit_chat_history_table(table_name, schema_name)) + async def _adrop_table( self, table_name: str, @@ -378,3 +437,40 @@ async def drop_table( self._run_as_sync( self._adrop_table(table_name=table_name, schema_name=schema_name) ) + + async def _aload_table_schema( + self, table_name: str, schema_name: str = "public" + ) -> list[str]: + """ + Load table schema from an existing table in a PgSQL database, potentially from a specific database schema. + + Args: + table_name: The name of the table to load the table schema from. + schema_name: The name of the database schema where the table resides. + Default: "public". + + Returns: + (list[str]): list of all column names in the table. + """ + + query = """ + SELECT column_name + FROM information_schema.columns + WHERE table_schema = :schema + AND table_name = :table + ORDER BY ordinal_position; + """ + + async with self._pool.connect() as conn: + result = await conn.execute( + text(query), {"schema": schema_name, "table": table_name} + ) + result_map = result.mappings() + results = result_map.fetchall() + + column_names = [row["column_name"] for row in results] + + if column_names: + return column_names + else: + raise ValueError(f'Table, "{schema_name}"."{table_name}", does not exist: ') diff --git a/tests/unit_tests/v2/test_async_chat_message_history.py b/tests/unit_tests/v2/test_async_chat_message_history.py new file mode 100644 index 00000000..30579a1e --- /dev/null +++ b/tests/unit_tests/v2/test_async_chat_message_history.py @@ -0,0 +1,139 @@ +import uuid +from typing import AsyncIterator + +import pytest +import pytest_asyncio +from langchain_core.messages.ai import AIMessage +from langchain_core.messages.human import HumanMessage +from langchain_core.messages.system import SystemMessage +from sqlalchemy import text 
+ +from langchain_postgres import PGEngine, PostgresChatMessageHistory +from langchain_postgres.v2.async_chat_message_history import ( + AsyncPGChatMessageHistory, +) +from tests.utils import VECTORSTORE_CONNECTION_STRING, asyncpg_client + +TABLE_NAME = "message_store" + str(uuid.uuid4()) +TABLE_NAME_ASYNC = "message_store" + str(uuid.uuid4()) + + +async def aexecute(engine: PGEngine, query: str) -> None: + async with engine._pool.connect() as conn: + await conn.execute(text(query)) + await conn.commit() + + +@pytest_asyncio.fixture +async def async_engine() -> AsyncIterator[PGEngine]: + async_engine = PGEngine.from_connection_string(url=VECTORSTORE_CONNECTION_STRING) + await async_engine._ainit_chat_history_table(table_name=TABLE_NAME_ASYNC) + yield async_engine + # use default table for AsyncPGChatMessageHistory + query = f'DROP TABLE IF EXISTS "{TABLE_NAME_ASYNC}"' + await aexecute(async_engine, query) + await async_engine.close() + + +@pytest.mark.asyncio +async def test_chat_message_history_async( + async_engine: PGEngine, +) -> None: + history = await AsyncPGChatMessageHistory.create( + engine=async_engine, session_id="test", table_name=TABLE_NAME_ASYNC + ) + msg1 = HumanMessage(content="hi!") + msg2 = AIMessage(content="whats up?") + await history.aadd_message(msg1) + await history.aadd_message(msg2) + messages = await history._aget_messages() + + # verify messages are correct + assert messages[0].content == "hi!" + assert type(messages[0]) is HumanMessage + assert messages[1].content == "whats up?" 
+ assert type(messages[1]) is AIMessage + + # verify clear() clears message history + await history.aclear() + assert len(await history._aget_messages()) == 0 + + +@pytest.mark.asyncio +async def test_chat_message_history_sync_messages( + async_engine: PGEngine, +) -> None: + history1 = await AsyncPGChatMessageHistory.create( + engine=async_engine, session_id="test", table_name=TABLE_NAME_ASYNC + ) + history2 = await AsyncPGChatMessageHistory.create( + engine=async_engine, session_id="test", table_name=TABLE_NAME_ASYNC + ) + msg1 = HumanMessage(content="hi!") + msg2 = AIMessage(content="whats up?") + await history1.aadd_message(msg1) + await history2.aadd_message(msg2) + + assert len(await history1._aget_messages()) == 2 + assert len(await history2._aget_messages()) == 2 + + # verify clear() clears message history + await history2.aclear() + assert len(await history2._aget_messages()) == 0 + + +@pytest.mark.asyncio +async def test_chat_table_async(async_engine: PGEngine) -> None: + with pytest.raises(ValueError): + await AsyncPGChatMessageHistory.create( + engine=async_engine, session_id="test", table_name="doesnotexist" + ) + + +@pytest.mark.asyncio +async def test_v1_schema_support(async_engine: PGEngine) -> None: + table_name = "chat_history" + session_id = str(uuid.UUID(int=125)) + async with asyncpg_client() as async_connection: + await PostgresChatMessageHistory.adrop_table(async_connection, table_name) + await PostgresChatMessageHistory.acreate_tables(async_connection, table_name) + + chat_history = PostgresChatMessageHistory( + table_name, session_id, async_connection=async_connection + ) + + await chat_history.aadd_messages( + [ + SystemMessage(content="Meow"), + AIMessage(content="woof"), + HumanMessage(content="bark"), + ] + ) + + history = await AsyncPGChatMessageHistory.create( + engine=async_engine, session_id=session_id, table_name=table_name + ) + + messages = await history._aget_messages() + + assert len(messages) == 3 + + msg1 = 
HumanMessage(content="hi!") + await history.aadd_message(msg1) + + messages = await history._aget_messages() + + assert len(messages) == 4 + + await async_engine._adrop_table(table_name=table_name) + + +async def test_incorrect_schema(async_engine: PGEngine) -> None: + table_name = "incorrect_schema_" + str(uuid.uuid4()) + await async_engine._ainit_vectorstore_table(table_name=table_name, vector_size=1024) + with pytest.raises(IndexError): + await AsyncPGChatMessageHistory.create( + engine=async_engine, session_id="test", table_name=table_name + ) + query = f'DROP TABLE IF EXISTS "{table_name}"' + await aexecute(async_engine, query) diff --git a/tests/unit_tests/v2/test_chat_message_history.py b/tests/unit_tests/v2/test_chat_message_history.py new file mode 100644 index 00000000..749d9144 --- /dev/null +++ b/tests/unit_tests/v2/test_chat_message_history.py @@ -0,0 +1,184 @@ +import uuid +from typing import Any, AsyncIterator + +import pytest +import pytest_asyncio +from langchain_core.messages.ai import AIMessage +from langchain_core.messages.human import HumanMessage +from sqlalchemy import text + +from langchain_postgres import PGChatMessageHistory, PGEngine +from tests.utils import VECTORSTORE_CONNECTION_STRING + +TABLE_NAME = "message_store" + str(uuid.uuid4()) +TABLE_NAME_ASYNC = "message_store" + str(uuid.uuid4()) + + +async def aexecute( + engine: PGEngine, + query: str, +) -> None: + async def run(engine: PGEngine, query: str) -> None: + async with engine._pool.connect() as conn: + await conn.execute(text(query)) + await conn.commit() + + await engine._run_as_async(run(engine, query)) + + +@pytest_asyncio.fixture +async def engine() -> AsyncIterator[PGEngine]: + engine = PGEngine.from_connection_string(url=VECTORSTORE_CONNECTION_STRING) + engine.init_chat_history_table(table_name=TABLE_NAME) + yield engine + # use default table for PGChatMessageHistory + query = f'DROP TABLE IF EXISTS "{TABLE_NAME}"' + await aexecute(engine, query) + await engine.close() + 
+ +@pytest_asyncio.fixture +async def async_engine() -> AsyncIterator[PGEngine]: + async_engine = PGEngine.from_connection_string(url=VECTORSTORE_CONNECTION_STRING) + await async_engine.ainit_chat_history_table(table_name=TABLE_NAME_ASYNC) + yield async_engine + # use default table for PGChatMessageHistory + query = f'DROP TABLE IF EXISTS "{TABLE_NAME_ASYNC}"' + await aexecute(async_engine, query) + await async_engine.close() + + +def test_chat_message_history(engine: PGEngine) -> None: + history = PGChatMessageHistory.create_sync( + engine=engine, session_id="test", table_name=TABLE_NAME + ) + history.add_user_message("hi!") + history.add_ai_message("whats up?") + messages = history.messages + + # verify messages are correct + assert messages[0].content == "hi!" + assert type(messages[0]) is HumanMessage + assert messages[1].content == "whats up?" + assert type(messages[1]) is AIMessage + + # verify clear() clears message history + history.clear() + assert len(history.messages) == 0 + + +def test_chat_table(engine: Any) -> None: + with pytest.raises(ValueError): + PGChatMessageHistory.create_sync( + engine=engine, session_id="test", table_name="doesnotexist" + ) + + +async def test_incorrect_schema_async(async_engine: PGEngine) -> None: + table_name = "incorrect_schema_" + str(uuid.uuid4()) + await async_engine.ainit_vectorstore_table(table_name=table_name, vector_size=1024) + with pytest.raises(IndexError): + await PGChatMessageHistory.create( + engine=async_engine, session_id="test", table_name=table_name + ) + query = f'DROP TABLE IF EXISTS "{table_name}"' + await aexecute(async_engine, query) + + +async def test_incorrect_schema_sync(async_engine: PGEngine) -> None: + table_name = "incorrect_schema_" + str(uuid.uuid4()) + async_engine.init_vectorstore_table(table_name=table_name, vector_size=1024) + with pytest.raises(IndexError): + PGChatMessageHistory.create_sync( + engine=async_engine, session_id="test", table_name=table_name + ) + query = f'DROP TABLE IF 
EXISTS "{table_name}"' + await aexecute(async_engine, query) + + +@pytest.mark.asyncio +async def test_chat_message_history_async( + async_engine: PGEngine, +) -> None: + history = await PGChatMessageHistory.create( + engine=async_engine, session_id="test", table_name=TABLE_NAME_ASYNC + ) + msg1 = HumanMessage(content="hi!") + msg2 = AIMessage(content="whats up?") + await history.aadd_message(msg1) + await history.aadd_message(msg2) + messages = history.messages + + # verify messages are correct + assert messages[0].content == "hi!" + assert type(messages[0]) is HumanMessage + assert messages[1].content == "whats up?" + assert type(messages[1]) is AIMessage + + # verify clear() clears message history + await history.aclear() + assert len(history.messages) == 0 + + +@pytest.mark.asyncio +async def test_chat_message_history_sync_messages( + async_engine: PGEngine, +) -> None: + history1 = await PGChatMessageHistory.create( + engine=async_engine, session_id="test", table_name=TABLE_NAME_ASYNC + ) + history2 = await PGChatMessageHistory.create( + engine=async_engine, session_id="test", table_name=TABLE_NAME_ASYNC + ) + msg1 = HumanMessage(content="hi!") + msg2 = AIMessage(content="whats up?") + await history1.aadd_message(msg1) + await history2.aadd_message(msg2) + + assert len(history1.messages) == 2 + assert len(history2.messages) == 2 + + # verify clear() clears message history + await history2.aclear() + assert len(history2.messages) == 0 + + +@pytest.mark.asyncio +async def test_chat_message_history_set_messages( + async_engine: PGEngine, +) -> None: + history = await PGChatMessageHistory.create( + engine=async_engine, session_id="test", table_name=TABLE_NAME_ASYNC + ) + msg1 = HumanMessage(content="hi!") + msg2 = AIMessage(content="bye -_-") + # verify setting messages property adds to message history + history.messages = [msg1, msg2] + assert len(history.messages) == 2 + + +@pytest.mark.asyncio +async def test_chat_table_async(async_engine: PGEngine) -> None: + 
with pytest.raises(ValueError): + await PGChatMessageHistory.create( + engine=async_engine, session_id="test", table_name="doesnotexist" + ) + + +@pytest.mark.asyncio +async def test_cross_env_chat_message_history(engine: PGEngine) -> None: + history = PGChatMessageHistory.create_sync( + engine=engine, session_id="test_cross", table_name=TABLE_NAME + ) + await history.aadd_message(HumanMessage(content="hi!")) + messages = history.messages + assert messages[0].content == "hi!" + history.clear() + + history = await PGChatMessageHistory.create( + engine=engine, session_id="test_cross", table_name=TABLE_NAME + ) + history.add_message(HumanMessage(content="hi!")) + messages = history.messages + assert messages[0].content == "hi!" + history.clear()