Agent for MongoDB | LangChain
OpenAI has made its API available for exploration, allowing users to tailor it to their specific needs. In this context, LangChain plays a crucial role by providing a wrapper that simplifies coding and enhances the usability of these tools. It does so by offering multiple agents, loaders, and more, thereby reducing the overall coding effort.
LangChain already provides an agent for SQL databases that streamlines fetching relevant data. For NoSQL databases like MongoDB, however, there is currently no dedicated agent offering this facility.
This lack creates a significant challenge in retrieving relevant data from these databases, especially when every query must be crafted manually. This process becomes particularly time-consuming and cumbersome in large-scale projects or in scenarios where the required data depends entirely on user input.
In such situations, the need for a “custom agent” becomes evident. This agent would specialize in managing MongoDB queries, adapting dynamically to user inputs, and efficiently retrieving the necessary data without the overhead of manual query generation. This innovation could significantly enhance data retrieval processes, particularly in complex and user-driven projects.
These are the packages you need to install in order to create this agent.
pip install openai
pip install langchain
pip install python-dotenv
pip install pymongo
Setting Up BaseChain
Set up BaseChain, which is initialized with the model name, the OpenAI key, and the temperature. It also declares the functions that the query builder chain must implement.
chains -> base_chain.py
import os
from abc import ABC, abstractmethod
from typing import List

from langchain.chains import LLMChain
from langchain.llms.openai import OpenAI


class BaseChain(ABC):
    llm = None
    """OpenAI model"""

    def __init__(self, model_name, temperature, **kwargs) -> None:
        # Extra keyword arguments (e.g. max_tokens) are forwarded to the OpenAI wrapper
        self.llm = OpenAI(
            openai_api_key=os.getenv("OPENAI_API_KEY"),
            model_name=model_name,
            temperature=temperature,
            **kwargs,
        )

    @property
    @abstractmethod
    def template(self) -> str:
        """Template string that is required to build the PromptTemplate"""

    @property
    @abstractmethod
    def input_variables(self) -> List[str]:
        """Names of the template placeholders that are filled at run time"""

    @staticmethod
    @abstractmethod
    def populate_input_variables(payload: dict) -> dict:
        """Populate the input variables by using the provided payload"""

    @abstractmethod
    def partial_variables(self) -> dict:
        """Populate partial variables while creating the chain"""

    @abstractmethod
    def chain(self) -> LLMChain:
        """Create and return the LLMChain"""
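Because every method on BaseChain is abstract, Python refuses to instantiate a subclass that forgets one of them, so a missing method surfaces before any API call is made. The sketch below shows this with a dependency-free stand-in; MiniBaseChain, IncompleteChain and EchoChain are hypothetical names, and the lambda merely stands in for an LLMChain.

```python
from abc import ABC, abstractmethod


class MiniBaseChain(ABC):
    """Hypothetical stand-in for BaseChain, without the OpenAI dependency."""

    @abstractmethod
    def partial_variables(self) -> dict:
        """Values fixed once when the chain is built."""

    @abstractmethod
    def chain(self):
        """Build and return something callable."""


class IncompleteChain(MiniBaseChain):
    # chain() is missing, so this class is still abstract.
    def partial_variables(self) -> dict:
        return {}


class EchoChain(MiniBaseChain):
    def partial_variables(self) -> dict:
        return {"format_instructions": "Return JSON."}

    def chain(self):
        # A plain function standing in for an LLMChain.
        suffix = self.partial_variables()["format_instructions"]
        return lambda text: f"{text} | {suffix}"


def can_instantiate(cls) -> bool:
    """True if Python allows creating an instance of cls."""
    try:
        cls()
    except TypeError:
        return False
    return True


print(can_instantiate(IncompleteChain), can_instantiate(EchoChain))  # prints: False True
```

The same rule applies to the real classes: a query builder chain that omits, say, partial_variables() raises TypeError as soon as it is constructed.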
After this step, we need to create the basic structure of our chain. The construction of our chain requires the following components:
- Input Variable
- Partial Variable
- Prompt Template
- Output Parser
To separate the user’s query from the rest of the prompt, wrap each input variable in delimiters such as “###” or “$”.
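A minimal sketch of the delimiter idea, assuming the user's text is untrusted: the hypothetical helper sanitize_user_message removes any copies of the delimiter the user typed, so the text cannot close the ### block early, before it is substituted into the template. The template here is a shortened version of the one used later.

```python
DELIMITER = "###"
TEMPLATE = (
    "Create a MongoDB raw aggregation query for the following user question:\n"
    "###{user_message}###"
)


def sanitize_user_message(message: str, delimiter: str = DELIMITER) -> str:
    """Remove the delimiter from user text and collapse the leftover whitespace."""
    without_delimiter = message.replace(delimiter, " ")
    return " ".join(without_delimiter.split())


user_text = "Get users ### younger than 25"
prompt = TEMPLATE.format(user_message=sanitize_user_message(user_text))
print(prompt)
```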
All the code up to “Run the Chain” is part of query_builder_chain.py.
chains -> query_builder_chain -> query_builder_chain.py
Output Parser
Create the output parser and define the type of output you want the chain to respond with.
The key’s name and description matter, because they are included in the format instructions and therefore shape the model’s response.
from langchain.output_parsers import PydanticOutputParser
from pydantic import BaseModel, Field


class Response(BaseModel):
    # The field name and description end up in the format instructions
    query: list = Field(description="It must be a list")

    def to_dict(self):
        return self.query


parser = PydanticOutputParser(pydantic_object=Response)
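To make the parser's job concrete, here is a rough, standard-library-only approximation of the check it performs for the Response model: locate the JSON object in the model's raw text and require a list under "query". This is an illustrative sketch (parse_query_response is a hypothetical name), not LangChain's implementation, which also validates the data against the Pydantic model.

```python
import json
import re


def parse_query_response(text: str) -> list:
    """Extract {"query": [...]} from raw model output and return the list."""
    # Greedy match from the first "{" to the last "}" across lines.
    match = re.search(r"\{.*\}", text, re.DOTALL)
    if match is None:
        raise ValueError("no JSON object found in the model output")
    data = json.loads(match.group(0))
    query = data.get("query")
    if not isinstance(query, list):
        raise ValueError('"query" must be a list of pipeline stages')
    return query


raw = 'Here is the pipeline:\n{"query": [{"$match": {"user_id": "u1"}}, {"$limit": 5}]}'
print(parse_query_response(raw))
```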
Table Schema and Schema Description
This is the main part of the chain: we have to pass the table schema and the schema description.
It’s crucial to understand that the model might not recognize the specific keys you’ve defined in your table. This challenge was something I faced, and the solution I discovered involves passing the schema along with an essential description of the schema.
Since the example provided is just a placeholder, you can include as many details as needed in the schema description to help the model understand. It’s essential to ensure that this information is well-formatted.
You can define the schema and its description in a separate file and then load them into the main file. This approach keeps the code clean and makes it reusable.
table_schema = """
{
"user_id": "string",
"age": "string",
"gender": "string",
"height": "string"
}
"""
schema_description = """
Here is the description to determine what each key represents
1. **user_id**:
- Description: Unique identifier for the user.
2. **age**:
- Description: User's age in years.
3. **gender**:
- Description: User's gender (male/female).
4. **race**:
- Description: User's racial or ethnic background.
5. **height**:
- Description: User's height.
"""
MongoDB Agent
Let’s create a main class where we define our prompt template and input variables.
First, initialize the chain with model_name, temperature, and max_tokens. Then create the template for the chain, which is responsible for producing the query. I suggest using the MongoDB aggregation pipeline, as it is versatile and can express a wide range of queries.
import os

from chains.base_chain import BaseChain


class MongoDBQueryAgent(BaseChain):
    def __init__(self) -> None:
        super().__init__(
            model_name=os.getenv("MODEL_Name"),
            max_tokens=int(os.getenv("DEFAULT_COMPLETION_TOKEN_LIMIT")),
            # float, not int, so values such as "0.2" are accepted
            temperature=float(os.getenv("DEFAULT_TEMPERATURE")),
        )
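The constructor reads three environment variables (typically from the .env file loaded with python-dotenv). A small sketch, assuming those same variable names, that fails fast with a clear message when one is missing and parses the temperature as a float, since int("0.2") raises a ValueError. ModelSettings and load_model_settings are hypothetical names.

```python
import os
from dataclasses import dataclass
from typing import Mapping


@dataclass(frozen=True)
class ModelSettings:
    model_name: str
    max_tokens: int
    temperature: float


REQUIRED = ("MODEL_Name", "DEFAULT_COMPLETION_TOKEN_LIMIT", "DEFAULT_TEMPERATURE")


def load_model_settings(env: Mapping[str, str] = os.environ) -> ModelSettings:
    """Read the settings MongoDBQueryAgent needs, raising if any is missing."""
    missing = [key for key in REQUIRED if not env.get(key)]
    if missing:
        raise KeyError(f"missing environment variables: {', '.join(missing)}")
    return ModelSettings(
        model_name=env["MODEL_Name"],
        max_tokens=int(env["DEFAULT_COMPLETION_TOKEN_LIMIT"]),
        # float() accepts both "0" and "0.2"; int() would reject "0.2"
        temperature=float(env["DEFAULT_TEMPERATURE"]),
    )
```

Inside MongoDBQueryAgent.__init__, this could replace the three os.getenv calls with super().__init__(**dataclasses.asdict(load_model_settings())), since the field names match the constructor's keyword arguments.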
Prompt Template and Input Variable:
First, we create the prompt template.
user_message refers to any input from which you want to generate a MongoDB query, for example “Get the recent records of users younger than 25”. The other placeholders are:
1. table_schema is your MongoDB collection schema.
2. schema_description is the textual description we defined so the model understands what each key in the collection represents.
3. user_id is the ID value that the $match stage must filter on.
4. format_instructions holds the format instructions produced by the output parser. It is supplied as a partial variable, so it is referenced in the prompt but does not need to be listed in input_variables.
This prompt is highly adaptable and can be tailored to suit your specific requirements.
template = """
Create a MongoDB raw aggregation query for the following user question:
###{user_message}###
This is table schema : "{table_schema}"
This is schema description : ${schema_description}$
#add multiple instruction based on your requirenment
use match statement and inside match don't use `id` instead of `id` use "user_id" with the value of "{user_id}"
Just return the [] of aggregration pipeline.
The following is the format instructions.
***{format_instructions}***
"""
input_variables = [
    "user_message",
    "schema_description",
    "table_schema",
    "user_id",
]
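Because Python silently joins adjacent string literals, a single missing comma in input_variables (for example "table_schema" "user_id" written without a comma) produces one bogus name, "table_schemauser_id". A stdlib sketch, using hypothetical helpers template_fields and uncovered_placeholders, that lists every {placeholder} in an abbreviated copy of the template with string.Formatter and reports the ones not covered by the declared input or partial variables:

```python
from string import Formatter

# Abbreviated copy of the prompt template above
TEMPLATE = """
###{user_message}###
This is table schema : "{table_schema}"
This is schema description : ${schema_description}$
use "user_id" with the value of "{user_id}"
***{format_instructions}***
"""


def template_fields(template: str) -> set:
    """Names of all {placeholders} in a str.format-style template."""
    return {name for _, name, _, _ in Formatter().parse(template) if name}


def uncovered_placeholders(template: str, input_variables, partial_variables) -> set:
    """Placeholders that neither list supplies; an empty set means every one is covered."""
    return template_fields(template) - set(input_variables) - set(partial_variables)


# Without the comma, Python joins the last two literals into "table_schemauser_id".
broken = ["user_message", "schema_description", "table_schema" "user_id"]
fixed = ["user_message", "schema_description", "table_schema", "user_id"]
print(uncovered_placeholders(TEMPLATE, broken, ["format_instructions"]))
```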
Main Chain
After defining these, create the function that populates the input variables, the function that supplies the partial variables (the output parser’s format instructions), and the chain function that ties them together.
The chain function builds a PromptTemplate from the template defined above and wraps it, together with the LLM (Large Language Model) defined in BaseChain, in an LLMChain. The input variables themselves are filled in later, when the chain is run.
    @staticmethod
    def populate_input_variables(payload: dict) -> dict:
        return {
            "user_message": payload["question"],
            "user_id": payload["user_id"],
            "schema_description": schema_description,
            "table_schema": table_schema,
        }

    def partial_variables(self) -> dict:
        # Filled once, from the output parser
        return {"format_instructions": parser.get_format_instructions()}

    def chain(self) -> LLMChain:
        chain_prompt = PromptTemplate(
            input_variables=self.input_variables,
            template=self.template,
            partial_variables=self.partial_variables(),
        )
        queryBuilderAgent = LLMChain(
            llm=self.llm,
            prompt=chain_prompt,
        )
        return queryBuilderAgent
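Conceptually, the prompt that reaches the model is the template filled from two sources: partial variables fixed once when the chain is built (here, the parser's format instructions) and input variables supplied on each call. A stdlib sketch using plain str.format as a stand-in for PromptTemplate; render_prompt is a hypothetical helper, and LangChain's own implementation differs in detail.

```python
def render_prompt(template: str, partial_variables: dict, input_variables: dict) -> str:
    """Fill a str.format template from fixed partials plus per-call inputs.

    Raises ValueError if a name is supplied twice and KeyError if a placeholder
    is left unfilled.
    """
    duplicated = partial_variables.keys() & input_variables.keys()
    if duplicated:
        raise ValueError(f"variables supplied twice: {sorted(duplicated)}")
    return template.format(**partial_variables, **input_variables)


template = "###{user_message}###\nMatch user_id {user_id}.\n***{format_instructions}***"
partials = {"format_instructions": "Return JSON."}  # fixed when the chain is built
inputs = {"user_message": "users aged 25 or under", "user_id": "u1"}  # per call
print(render_prompt(template, partials, inputs))
```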
Complete Code
Here is the complete code for the MongoDB agent.
chains -> query_builder_chain -> query_builder_chain.py
import os
from typing import List

from langchain.chains import LLMChain
from langchain.output_parsers import PydanticOutputParser
from langchain.prompts import PromptTemplate
from pydantic import BaseModel, Field

from chains.base_chain import BaseChain


class Response(BaseModel):
    # The field name and description end up in the format instructions
    query: list = Field(description="It must be a list")

    def to_dict(self):
        return self.query


parser = PydanticOutputParser(pydantic_object=Response)

table_schema = """
{
    "user_id": "string",
    "age": "string",
    "gender": "string",
    "height": "string"
}
"""

schema_description = """
Here is the description to determine what each key represents
1. **user_id**:
   - Description: Unique identifier for the user.
2. **age**:
   - Description: User's age in years.
3. **gender**:
   - Description: User's gender (male/female).
4. **race**:
   - Description: User's racial or ethnic background.
5. **height**:
   - Description: User's height.
"""


class MongoDBQueryAgent(BaseChain):
    def __init__(self) -> None:
        super().__init__(
            model_name=os.getenv("MODEL_Name"),
            max_tokens=int(os.getenv("DEFAULT_COMPLETION_TOKEN_LIMIT")),
            # float, not int, so values such as "0.2" are accepted
            temperature=float(os.getenv("DEFAULT_TEMPERATURE")),
        )

    template = """
Create a MongoDB raw aggregation query for the following user question:
###{user_message}###
This is table schema : "{table_schema}"
This is schema description : ${schema_description}$
#add multiple instructions based on your requirements
Use a $match stage, and inside $match do not use `id`; use "user_id" with the value "{user_id}".
Just return the [] of the aggregation pipeline.
The following are the format instructions.
***{format_instructions}***
"""

    # Filled per call; format_instructions is a partial variable, so it is not listed here
    input_variables: List[str] = [
        "user_message",
        "schema_description",
        "table_schema",
        "user_id",
    ]

    @staticmethod
    def populate_input_variables(payload: dict) -> dict:
        return {
            "user_message": payload["question"],
            "user_id": payload["user_id"],
            "schema_description": schema_description,
            "table_schema": table_schema,
        }

    def partial_variables(self) -> dict:
        return {"format_instructions": parser.get_format_instructions()}

    def chain(self) -> LLMChain:
        chain_prompt = PromptTemplate(
            input_variables=self.input_variables,
            template=self.template,
            partial_variables=self.partial_variables(),
        )
        queryBuilderAgent = LLMChain(
            llm=self.llm,
            prompt=chain_prompt,
        )
        return queryBuilderAgent
Run the Chain
The final step is to execute the chain and obtain the query.
main.py
from dotenv import load_dotenv

from chains.query_builder_chain.query_builder_chain import MongoDBQueryAgent
from chains.query_builder_chain.query_builder_chain import parser as queryBuilderParser

# Load OPENAI_API_KEY, MODEL_Name, etc. from the .env file
load_dotenv()

query_builder_chain = MongoDBQueryAgent().chain()
query_builder_chain_input = MongoDBQueryAgent.populate_input_variables({
    "question": "Get the users whose age is less than or equal to 25",
    "user_id": "any_mongo_db_id",
})
query_builder_chain_response = query_builder_chain.run(query_builder_chain_input)
query_builder_parsed_response = queryBuilderParser.parse(query_builder_chain_response).to_dict()
print("MongoDB Raw Query", query_builder_parsed_response)
This returns only the MongoDB aggregation pipeline (a list of stages), which you can then run from your Python code, for example with PyMongo’s Collection.aggregate().
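Before executing a pipeline that an LLM wrote, it is prudent to check it. The sketch below assumes a PyMongo Collection, whose aggregate() method returns a cursor, and the prompt's convention that the first stage is a $match on user_id. validate_pipeline and run_pipeline are hypothetical helpers. Rejecting $out and $merge, the stages that write results to a collection, is a policy choice rather than a MongoDB requirement.

```python
FORBIDDEN_STAGES = {"$out", "$merge"}  # stages that write results to a collection


def validate_pipeline(pipeline: list, user_id: str) -> list:
    """Reject malformed or write-capable pipelines; return the pipeline if it passes."""
    if not isinstance(pipeline, list) or not pipeline:
        raise ValueError("pipeline must be a non-empty list of stages")
    for stage in pipeline:
        if not isinstance(stage, dict) or len(stage) != 1:
            raise ValueError(f"each stage must be a single-key dict: {stage!r}")
        (name,) = stage
        if name in FORBIDDEN_STAGES:
            raise ValueError(f"stage {name} is not allowed")
    match = pipeline[0].get("$match")
    if not isinstance(match, dict) or match.get("user_id") != user_id:
        raise ValueError("the first stage must $match on the expected user_id")
    return pipeline


def run_pipeline(collection, pipeline: list, user_id: str) -> list:
    """Validate, then run via collection.aggregate() and materialize the cursor."""
    return list(collection.aggregate(validate_pipeline(pipeline, user_id)))
```

For example, run_pipeline(MongoClient(uri)["db"]["users"], query_builder_parsed_response, "any_mongo_db_id"), where uri and the database and collection names are placeholders. The user_id check is deliberately strict: it will also reject pipelines that put the user_id condition inside an $and.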