Agent for MongoDB | Langchain

Omer Shahzad
6 min read · Jan 5, 2024


OpenAI has made its API available for exploration, allowing users to tailor it to their specific needs. In this context, LangChain plays a crucial role by providing a wrapper that simplifies coding and enhances the usability of these tools. It does so by offering multiple agents, loaders, and more, thereby reducing the overall coding effort.

It’s well-known that there’s an agent capable of efficiently handling SQL databases, streamlining the process of fetching relevant data. However, when it comes to NoSQL databases like MongoDB, there currently isn’t a dedicated agent offering this facility.

This lack creates a significant challenge in retrieving relevant data from these databases, especially when every query must be crafted manually. This process becomes particularly time-consuming and cumbersome in large-scale projects or in scenarios where the required data depends entirely on user input.


In such situations, the need for a “custom agent” becomes evident. This agent would specialize in managing MongoDB queries, adapting dynamically to user inputs, and efficiently retrieving the necessary data without the overhead of manual query generation. This innovation could significantly enhance data retrieval processes, particularly in complex and user-driven projects.

These are the packages you need to install in order to create this agent.

    pip install openai
    pip install langchain
    pip install python-dotenv
    pip install pymongo

Setting BaseChain

Set up the BaseChain, which is initialized with the model name, the OpenAI key, and the temperature. It also declares the abstract members that our query builder chain will have to implement.

chains -> base_chain.py

    import os
    from abc import ABC, abstractmethod
    from typing import List

    from langchain.chains import LLMChain
    from langchain.llms.openai import OpenAI


    class BaseChain(ABC):
        """Common setup shared by every chain."""

        llm = None
        """OpenAI model"""

        def __init__(self, model_name, temperature, max_tokens=None, **kwargs) -> None:
            # Only forward max_tokens when the caller actually provided it
            llm_kwargs = {"max_tokens": max_tokens} if max_tokens is not None else {}
            self.llm = OpenAI(
                openai_api_key=os.getenv("OPENAI_API_KEY"),
                model_name=model_name,
                temperature=temperature,
                **llm_kwargs,
            )

        @property
        @abstractmethod
        def template(self) -> str:
            """Template String that is required to build PromptTemplate"""

        @property
        @abstractmethod
        def input_variables(self) -> List[str]:
            """Array of strings that are required as variable inputs for placeholders in template"""

        @staticmethod
        @abstractmethod
        def populate_input_variables(payload) -> dict:
            """Populate the input variables by using the provided payload"""

        @abstractmethod
        def partial_variables(self) -> dict:
            """Populate partial variables while creating the chain"""

        @abstractmethod
        def chain(self) -> LLMChain:
            """Create and return LLM Chain"""

After this step, we need to create the basic structure of our chain. The construction of our chain requires the following components:

  1. Input Variable
  2. Partial Variable
  3. Prompt Template
  4. Output Parser

To set the user query and the other input variables apart from the surrounding instructions in the prompt, we can wrap them in delimiters such as “#” or “$”, among others.
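As a minimal sketch of the delimiter idea, the snippet below wraps the user query in `###` using plain `str.format`. The helper name `build_prompt` and the sample question are illustrative assumptions, not part of LangChain.

```python
# Hypothetical helper: wraps user text in ### delimiters so the model can
# tell the question apart from the surrounding instructions.
TEMPLATE = (
    "Create a MongoDB raw aggregation query for the following user question:\n"
    "###{user_message}###\n"
)


def build_prompt(user_message: str) -> str:
    # Drop any delimiter the user typed so the question cannot close the
    # delimited region early.
    cleaned = user_message.replace("###", "").strip()
    return TEMPLATE.format(user_message=cleaned)


prompt = build_prompt("Get users younger than 25")
print(prompt)
```

Any delimiter works as long as it is unlikely to appear in normal input and is used consistently across the prompt.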

All the code up to the “Run the Chain” section belongs in query_builder_chain.py.

chains -> query_builder_chain -> query_builder_chain.py

Output Parser

Create the output parser and define the type of output you want the chain to respond with.

How you define the key (its name and its description) is important, as it influences the response.

    from langchain.output_parsers import PydanticOutputParser
    from pydantic import BaseModel, Field


    class Response(BaseModel):
        query: list = Field(description="It must be a list")

        def to_dict(self):
            # Despite the name, this returns the pipeline list itself
            return self.query


    parser = PydanticOutputParser(pydantic_object=Response)
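To illustrate what the parser is expected to do with the model's reply, here is a standard-library sketch that mimics the check: it pulls the JSON object out of the reply and insists that `query` is a list. It is a simplified stand-in, not how `PydanticOutputParser` is implemented, and `parse_reply` and the sample reply are assumptions.

```python
import json


def parse_reply(reply: str) -> list:
    """Extract {"query": [...]} from a model reply and return the list."""
    # Models often wrap JSON in prose or code fences, so take the span
    # between the first "{" and the last "}".
    start, end = reply.find("{"), reply.rfind("}")
    if start == -1 or end == -1:
        raise ValueError("no JSON object found in reply")
    data = json.loads(reply[start : end + 1])
    query = data.get("query")
    if not isinstance(query, list):
        raise ValueError('"query" must be a list')
    return query


sample_reply = 'Here you go: {"query": [{"$match": {"age": {"$lte": 25}}}]}'
print(parse_reply(sample_reply))
```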

Table Schema and Schema Description

This is the main part of the chain: we have to pass the table schema and the schema description.

It’s crucial to understand that the model might not recognize the specific keys you’ve defined in your table. This challenge was something I faced, and the solution I discovered involves passing the schema along with an essential description of the schema.

Since the example provided is just a placeholder, you can include as many details as needed in the schema description to help the model understand. It’s essential to ensure that this information is well-formatted.

You can define the schema and its description in a separate file, and then load them into the main file. This approach helps keep the code clean and makes it reusable.
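Because the model only sees the names you give it, it helps if every key in the schema also appears in the description under exactly the same name. The check below is a small sketch of that idea; `missing_descriptions` and the sample strings are hypothetical, and it assumes the description marks each key as `**key**`, the style used in this article's placeholder.

```python
import json
import re


def missing_descriptions(table_schema: str, schema_description: str) -> list:
    """Return schema keys that have no **key** entry in the description."""
    schema_keys = json.loads(table_schema).keys()
    described = set(re.findall(r"\*\*(\w+)\*\*", schema_description))
    return [key for key in schema_keys if key not in described]


sample_schema = '{"user_id": "string", "age": "string"}'
sample_description = """
1. **userId**:
- Description: Unique identifier for the user.
2. **age**:
- Description: User's age in years.
"""
# "user_id" is reported because the description spells it "userId"
print(missing_descriptions(sample_schema, sample_description))
```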

    table_schema = """
    {
    "user_id": "string",
    "age": "string",
    "gender": "string",
    "race": "string",
    "height": "string"
    }
    """
    schema_description = """

    Here is the description to determine what each key represents
    1. **user_id**:
    - Description: Unique identifier for the user.
    2. **age**:
    - Description: User's age in years.
    3. **gender**:
    - Description: User's gender (male/female).
    4. **race**:
    - Description: User's racial or ethnic background.
    5. **height**:
    - Description: User's height.
    """

MongoDB Agent

Let’s create a main class where we can define our prompt template and input variable.

First, initialize the chain with model_name, temperature, and max_tokens. Then, create the template for the chain, which will be responsible for creating the query. I suggest using the MongoDB aggregation pipeline as it is versatile and can be applied to a wide range of queries.
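For readers new to aggregation pipelines: a pipeline is an ordered list of stages, each a single-key document whose key is a stage operator such as `$match` or `$sort`. The sketch below shows the kind of list the chain is meant to return for a question like “users aged 25 or younger”, plus a hypothetical `is_pipeline` shape check. The field names follow the placeholder schema, and the example assumes `age` is stored as a number.

```python
# Illustrative pipeline: filter, then sort, then cap the result size
example_pipeline = [
    {"$match": {"user_id": "any_mongo_db_id", "age": {"$lte": 25}}},
    {"$sort": {"age": 1}},
    {"$limit": 10},
]


def is_pipeline(value) -> bool:
    """Shape check only: a list of single-key dicts with $-prefixed keys."""
    if not isinstance(value, list):
        return False
    for stage in value:
        if not isinstance(stage, dict) or len(stage) != 1:
            return False
        (operator,) = stage.keys()
        if not isinstance(operator, str) or not operator.startswith("$"):
            return False
    return True


print(is_pipeline(example_pipeline))
```

A shape check like this does not prove the query is correct; it only catches replies that are clearly not a pipeline.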

    import os

    from chains.base_chain import BaseChain


    class MongoDBQueryAgent(BaseChain):
        def __init__(self) -> None:
            super().__init__(
                model_name=os.getenv("MODEL_Name"),
                max_tokens=int(os.getenv("DEFAULT_COMPLETION_TOKEN_LIMIT")),
                # Temperature is usually fractional (e.g. 0.2), so parse it as a float
                temperature=float(os.getenv("DEFAULT_TEMPERATURE")),
            )
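When reading these settings, keep in mind that `os.getenv` returns `None` for an unset variable and that a temperature such as `0.2` has to be parsed with `float()`, not `int()`. A defensive way to read the same settings could look like the sketch below. The variable names come from the code above, while `read_settings` and the fallback defaults are assumptions.

```python
import os


def read_settings(env=os.environ) -> dict:
    """Read the chain settings with fallbacks (the defaults are assumptions)."""
    return {
        "model_name": env.get("MODEL_Name", "gpt-3.5-turbo-instruct"),
        "max_tokens": int(env.get("DEFAULT_COMPLETION_TOKEN_LIMIT", "256")),
        "temperature": float(env.get("DEFAULT_TEMPERATURE", "0")),
    }


# Passing a plain dict instead of os.environ makes the helper easy to try out
print(read_settings({"DEFAULT_TEMPERATURE": "0.2"}))
```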

Prompt Template and Input Variable:

First, we create the prompt template. It uses the following placeholders:

1. user_message is any kind of input from which you wish to generate a MongoDB query, for example “Get the recent records of users younger than 25”.

2. table_schema is your MongoDB table schema.

3. schema_description is the textual description we defined so that the model understands what each key in our DB represents.

4. user_id is the ID of the user whose records the query should match.

5. format_instructions holds the format instructions produced by the output parser. It is supplied as a partial variable, so it appears in the prompt but does not need to be listed in the input variables.

This prompt is highly adaptable and can be tailored to suit your specific requirements.

        # Class attributes of MongoDBQueryAgent
        template = """
        Create a MongoDB raw aggregation query for the following user question:
        ###{user_message}###

        This is table schema : "{table_schema}"
        This is schema description : ${schema_description}$

        # Add more instructions here based on your requirements
        Use a match stage, and inside it do not filter on `id`; use "user_id" with the value "{user_id}" instead.

        Just return the aggregation pipeline as a list [].
        The following is the format instructions.
        ***{format_instructions}***
        """

        # Every name needs its own comma; adjacent string literals are
        # silently joined into one string
        input_variables = [
            "user_message",
            "schema_description",
            "table_schema",
            "user_id",
        ]
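A quick way to catch a mismatch between the template and the variable list (for example a missing comma, which makes Python join two adjacent string literals into one name) is to compare the placeholders found in the template with the declared variables. The sketch uses only `string.Formatter` from the standard library; `undeclared_placeholders` and the shortened template are illustrative.

```python
from string import Formatter


def undeclared_placeholders(template: str, declared: list) -> set:
    """Return placeholders used in the template but missing from `declared`."""
    used = {name for _, name, _, _ in Formatter().parse(template) if name}
    return used - set(declared)


short_template = "###{user_message}### schema: {table_schema} id: {user_id}"

# Missing comma: "table_schema" "user_id" becomes "table_schemauser_id"
buggy = ["user_message", "table_schema" "user_id"]
fixed = ["user_message", "table_schema", "user_id"]

print(undeclared_placeholders(short_template, buggy))
print(undeclared_placeholders(short_template, fixed))
```

With the real template, remember to add the partial variables (here `format_instructions`) to the declared list before comparing.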

Main Chain

After defining these, create the function that populates the input variables, the function that supplies the output parser's format instructions as a partial variable, and the main chain function that ties all of these together to produce the desired response.

The chain function builds the prompt from the template and variables defined above, combines it with the LLM (Large Language Model) defined in the BaseChain, and returns the resulting chain.


        # These methods belong inside MongoDBQueryAgent

        @staticmethod
        def populate_input_variables(payload: dict) -> dict:
            return {
                "user_message": payload["question"],
                "user_id": payload["user_id"],
                "schema_description": schema_description,
                "table_schema": table_schema,
            }

        def partial_variables(self) -> dict:
            return {"format_instructions": parser.get_format_instructions()}

        def chain(self) -> LLMChain:
            chain_prompt = PromptTemplate(
                input_variables=self.input_variables,
                template=self.template,
                partial_variables=self.partial_variables(),
            )
            queryBuilderAgent = LLMChain(
                llm=self.llm,
                prompt=chain_prompt,
            )
            return queryBuilderAgent

Complete Code

Here is the complete code for the MongoDB agent, combined in one file.

chains -> query_builder_chain -> query_builder_chain.py

    import os

    from langchain.chains import LLMChain
    from langchain.output_parsers import PydanticOutputParser
    from langchain.prompts import PromptTemplate
    from pydantic import BaseModel, Field

    from chains.base_chain import BaseChain


    class Response(BaseModel):
        query: list = Field(description="It must be a list")

        def to_dict(self):
            # Despite the name, this returns the pipeline list itself
            return self.query


    parser = PydanticOutputParser(pydantic_object=Response)

    table_schema = """
    {
    "user_id": "string",
    "age": "string",
    "gender": "string",
    "race": "string",
    "height": "string"
    }
    """

    schema_description = """

    Here is the description to determine what each key represents
    1. **user_id**:
    - Description: Unique identifier for the user.
    2. **age**:
    - Description: User's age in years.
    3. **gender**:
    - Description: User's gender (male/female).
    4. **race**:
    - Description: User's racial or ethnic background.
    5. **height**:
    - Description: User's height.
    """


    class MongoDBQueryAgent(BaseChain):
        def __init__(self) -> None:
            super().__init__(
                model_name=os.getenv("MODEL_Name"),
                max_tokens=int(os.getenv("DEFAULT_COMPLETION_TOKEN_LIMIT")),
                # Temperature is usually fractional (e.g. 0.2), so parse it as a float
                temperature=float(os.getenv("DEFAULT_TEMPERATURE")),
            )

        template = """
        Create a MongoDB raw aggregation query for the following user question:
        ###{user_message}###

        This is table schema : "{table_schema}"
        This is schema description : ${schema_description}$

        # Add more instructions here based on your requirements
        Use a match stage, and inside it do not filter on `id`; use "user_id" with the value "{user_id}" instead.

        Just return the aggregation pipeline as a list [].
        The following is the format instructions.
        ***{format_instructions}***
        """

        # Every name needs its own comma; adjacent string literals are
        # silently joined into one string
        input_variables = [
            "user_message",
            "schema_description",
            "table_schema",
            "user_id",
        ]

        @staticmethod
        def populate_input_variables(payload: dict) -> dict:
            return {
                "user_message": payload["question"],
                "user_id": payload["user_id"],
                "schema_description": schema_description,
                "table_schema": table_schema,
            }

        def partial_variables(self) -> dict:
            return {"format_instructions": parser.get_format_instructions()}

        def chain(self) -> LLMChain:
            chain_prompt = PromptTemplate(
                input_variables=self.input_variables,
                template=self.template,
                partial_variables=self.partial_variables(),
            )
            queryBuilderAgent = LLMChain(
                llm=self.llm,
                prompt=chain_prompt,
            )
            return queryBuilderAgent

Run the Chain

The final step is to execute the chain and obtain the query.

main.py

    from dotenv import load_dotenv

    from chains.query_builder_chain.query_builder_chain import MongoDBQueryAgent
    from chains.query_builder_chain.query_builder_chain import parser as queryBuilderParser

    # Load OPENAI_API_KEY and the other settings from a .env file
    # before the chain reads them with os.getenv
    load_dotenv()

    query_builder_chain = MongoDBQueryAgent().chain()
    query_builder_chain_input = MongoDBQueryAgent.populate_input_variables({
        "question": "Get the users whose age is less than or equal to 25",
        "user_id": "any_mongo_db_id",
    })
    query_builder_chain_response = query_builder_chain.run(query_builder_chain_input)
    # Parse the raw reply into the Response model and take the pipeline list
    query_builder_parsed_response = queryBuilderParser.parse(query_builder_chain_response).to_dict()

    print("MongoDB Raw Query", query_builder_parsed_response)

This returns only the MongoDB aggregation pipeline as a list [], which you can then use in your Python code.
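As a hedged sketch of that last step, the helper below passes the generated list to a collection's `aggregate` method, which is how pymongo runs aggregation pipelines. Since the pipeline is model-generated, it first refuses stages that write data (`$out`, `$merge`). `run_pipeline`, the connection string, and the database and collection names in the comment are assumptions; `FakeCollection` exists only so the snippet runs without a database.

```python
WRITE_STAGES = {"$out", "$merge"}


def run_pipeline(collection, pipeline: list) -> list:
    """Run a generated pipeline, rejecting stages that would write data."""
    for stage in pipeline:
        if WRITE_STAGES & set(stage):
            raise ValueError(f"refusing write stage: {sorted(stage)}")
    return list(collection.aggregate(pipeline))


# With pymongo (the names below are placeholders):
#   from pymongo import MongoClient
#   collection = MongoClient("mongodb://localhost:27017")["mydb"]["users"]


class FakeCollection:
    """Stand-in that only supports one $match with an {"$lte": n} age filter."""

    def __init__(self, docs):
        self.docs = docs

    def aggregate(self, pipeline):
        limit = pipeline[0]["$match"]["age"]["$lte"]
        return iter(d for d in self.docs if d["age"] <= limit)


users = FakeCollection([{"user_id": "a", "age": 22}, {"user_id": "b", "age": 40}])
result = run_pipeline(users, [{"$match": {"age": {"$lte": 25}}}])
print(result)
```

Running the query under a database user with read-only permissions is a stronger safeguard than the stage check alone.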
