Tuesday, May 14, 2024

Defining Custom Tools in Langchain

When constructing your own agent, you will need to provide it with a list of Tools that it can use. Besides the actual function that is called, each Tool consists of several components:


name (str): required, and must be unique within the set of tools provided to an agent.

description (str): optional but recommended, because the agent uses it to decide when to use the tool.

args_schema (Pydantic BaseModel): optional but recommended; it can be used to provide more information (e.g., few-shot examples) or validation for the expected parameters.


# Import things that are needed generically

from langchain.pydantic_v1 import BaseModel, Field

from langchain.tools import BaseTool, StructuredTool, tool


from typing import Optional, Type


from langchain.callbacks.manager import (

    AsyncCallbackManagerForToolRun,

    CallbackManagerForToolRun,

)


class SearchInput(BaseModel):
    # Argument schema for CustomSearchTool (attached via its args_schema).
    # A single required string field; the description text is what the
    # agent sees for this argument.
    query: str = Field(..., description="should be a search query")



class CalculatorInput(BaseModel):
    # Argument schema for CustomCalculatorTool (attached via its args_schema).
    # Both operands are required integers.
    a: int = Field(..., description="first number")
    b: int = Field(..., description="second number")



class CustomSearchTool(BaseTool):
    """Example search tool that always answers with a fixed string.

    Arguments are described and validated by ``SearchInput`` (a single
    ``query`` string). Only the synchronous path is implemented; the async
    path raises ``NotImplementedError``.
    """

    # Overridden base-class fields carry explicit type annotations, matching
    # ``args_schema`` below; pydantic v2 rejects non-annotated overrides of
    # fields declared on a base model.
    name: str = "custom_search"
    description: str = "useful for when you need to answer questions about current events"
    args_schema: Type[BaseModel] = SearchInput

    def _run(
        self, query: str, run_manager: Optional[CallbackManagerForToolRun] = None
    ) -> str:
        """Run the search synchronously.

        Args:
            query: The search query. Ignored by this placeholder implementation.
            run_manager: Optional callback manager; unused here.

        Returns:
            The fixed placeholder answer ``"LangChain"``.
        """
        return "LangChain"

    async def _arun(
        self, query: str, run_manager: Optional[AsyncCallbackManagerForToolRun] = None
    ) -> str:
        """Asynchronous entry point, which this tool does not support.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError("custom_search does not support async")



class CustomCalculatorTool(BaseTool):
    """Example calculator tool that multiplies two integers.

    Arguments are described and validated by ``CalculatorInput`` (``a`` and
    ``b``). ``return_direct`` is set to True for this tool. Only the
    synchronous path is implemented; the async path raises
    ``NotImplementedError``.
    """

    # Overridden base-class fields carry explicit type annotations, matching
    # ``args_schema`` and ``return_direct`` below; pydantic v2 rejects
    # non-annotated overrides of fields declared on a base model.
    name: str = "Calculator"
    description: str = "useful for when you need to answer questions about math"
    args_schema: Type[BaseModel] = CalculatorInput
    return_direct: bool = True

    def _run(
        self, a: int, b: int, run_manager: Optional[CallbackManagerForToolRun] = None
    ) -> int:
        """Multiply the two operands.

        Args:
            a: First operand.
            b: Second operand.
            run_manager: Optional callback manager; unused here.

        Returns:
            The product ``a * b``. The return type is ``int`` because that is
            what this method actually produces; it is not converted to ``str``.
        """
        return a * b

    async def _arun(
        self,
        a: int,
        b: int,
        run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
    ) -> int:
        """Asynchronous entry point, which this tool does not support.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError("Calculator does not support async")


No comments:

Post a Comment