Skip to content

Instantly share code, notes, and snippets.

@tecnologer
Last active January 31, 2024 17:36
Show Gist options
  • Select an option

  • Save tecnologer/1ec9507bb5e7e440ea31deef19294079 to your computer and use it in GitHub Desktop.

Select an option

Save tecnologer/1ec9507bb5e7e440ea31deef19294079 to your computer and use it in GitHub Desktop.
Factory for building function-calling definitions (Functions objects) for the OpenAI API
package agent
import (
"context"
"encoding/json"
"fmt"
"github.com/pkg/errors"
"github.com/tmc/langchaingo/llms"
"github.com/tmc/langchaingo/llms/openai"
"github.com/tmc/langchaingo/schema"
)
// Chatter is an interface for interacting with the AI.
// It abstracts a single chat-completion round so callers can depend on it
// instead of the concrete *Chat, whose Call method has this exact signature.
type Chatter interface {
// Call sends messages (plus optional per-call options) to the model and
// returns its response.
Call(
ctx context.Context,
messages []llms.MessageContent,
options ...llms.CallOption,
) (*llms.ContentResponse, error)
}
// Chat is an implementation of Chatter that uses the OpenAI API.
// When functions is non-nil, its definitions are offered to the model and a
// function call in the model's answer is executed through it (see Call).
type Chat struct {
chat *openai.LLM // langchaingo OpenAI client created by NewChat
functions *factory.Functions // optional function registry; Call checks it for nil
}
// NewChat builds a Chat backed by an OpenAI client for the given model and
// API key. functions may be nil, in which case no functions are offered to
// the model. Client construction errors are wrapped with "failed to create chat".
func NewChat(model, apiKey string, functions *factory.Functions) (*Chat, error) {
	client, err := openai.New(openai.WithModel(model), openai.WithToken(apiKey))
	if err != nil {
		return nil, errors.Wrap(err, "failed to create chat")
	}

	c := &Chat{chat: client, functions: functions}
	return c, nil
}
// Call sends messages to the model and returns its response.
//
// When the Chat has a function registry and none of the supplied options
// already sets function definitions, the registry's definitions are added to
// the request. If the model answers with a function call and a registry is
// available, the function is executed and the conversation continues (see
// callFunction); the completion from that follow-up is returned. Without a
// registry the function call is handed back to the caller untouched.
func (c *Chat) Call(
	ctx context.Context,
	messages []llms.MessageContent,
	options ...llms.CallOption,
) (*llms.ContentResponse, error) {
	if c.functions != nil && !c.optionsContainsFunctions(options) {
		// Cap the slice at its length so append allocates instead of writing
		// into spare capacity of a caller-owned array (Call(ctx, m, opts...)).
		options = append(options[:len(options):len(options)], llms.WithFunctions(c.functions.Definitions))
	}

	completion, err := c.chat.GenerateContent(ctx, messages, options...)
	if err != nil {
		return nil, errors.Wrap(err, "failed to call chat")
	}

	// Dispatch only when there is a registry to dispatch to: callFunction
	// dereferences c.functions, which may be nil.
	if c.functions == nil ||
		completion == nil ||
		len(completion.Choices) == 0 ||
		completion.Choices[0] == nil ||
		completion.Choices[0].FuncCall == nil {
		return completion, nil
	}

	completion, err = c.callFunction(ctx, messages, completion.Choices[0].FuncCall)
	if err != nil {
		return nil, errors.Wrap(err, "failed to call function")
	}
	return completion, nil
}
// callFunction executes the function the model asked for and sends its
// outcome back to the model as an extra message, returning the model's
// follow-up completion.
//
// On success the message carries the marshaled receivers; on failure it
// carries the error text, JSON-quoted. Only llms.WithFunctions is passed to
// the follow-up Call; the original call options are not forwarded.
func (c *Chat) callFunction(
	ctx context.Context,
	messages []llms.MessageContent,
	functionCall *schema.FunctionCall,
) (*llms.ContentResponse, error) {
	nlogger.Debugf(">> call function %s with arguments: %s", functionCall.Name, functionCall.Arguments)

	var outcome llms.MessageContent
	content, execErr := c.functions.Exec(functionCall)
	if execErr == nil {
		outcome = FunctionMessagef(functionCall, "function executed correctly, result: %s", content)
	} else {
		nlogger.Warningf(">> call function %s. Err: %v", functionCall.Name, execErr)

		quoted, marshalErr := json.Marshal(execErr.Error())
		if marshalErr != nil {
			return nil, errors.Wrap(marshalErr, "failed to marshal error")
		}
		outcome = FunctionMessagef(functionCall, "function executed with error: %s", string(quoted))
	}

	followUp, err := c.Call(ctx, append(messages, outcome), llms.WithFunctions(c.functions.Definitions))
	if err != nil {
		return nil, errors.Wrap(err, "failed to call chat")
	}
	return followUp, nil
}
// optionsContainsFunctions reports whether any of the given call options
// already sets function definitions.
//
// llms.CallOption values are opaque closures, so their runtime names cannot
// identify them reliably (a closure returned by llms.WithFunctions is named
// "...WithFunctions.func1"). Instead, the options are applied to a scratch
// llms.CallOptions and its Functions field is inspected.
func (c *Chat) optionsContainsFunctions(options []llms.CallOption) bool {
	var probe llms.CallOptions
	for _, option := range options {
		if option != nil {
			option(&probe)
		}
	}
	return len(probe.Functions) > 0
}
// FunctionMessagef builds a message that reports a function's outcome to the
// model. It uses the AI role and has two text parts: "name: <function name>"
// and the text produced by formatting format with a.
func FunctionMessagef(funcCall *schema.FunctionCall, format string, a ...any) llms.MessageContent {
	header := fmt.Sprintf("name: %s", funcCall.Name)
	body := fmt.Sprintf(format, a...)
	return TextMessage(schema.ChatMessageTypeAI, header, body)
}
// HumaneMessage builds a single-part text message with the human role.
// The name is kept as-is for compatibility; it means "human message".
func HumaneMessage(content string) llms.MessageContent {
	role := schema.ChatMessageTypeHuman
	return TextMessage(role, content)
}
// TextMessage builds a message with the given role and one text part per
// content string, in order. Parts is never nil, even with no content.
func TextMessage(role schema.ChatMessageType, content ...string) llms.MessageContent {
	parts := make([]llms.ContentPart, 0, len(content))
	for i := range content {
		parts = append(parts, llms.TextContent{Text: content[i]})
	}
	return llms.MessageContent{Role: role, Parts: parts}
}
// TextMessagef builds a single-part text message whose text is format
// formatted with a (fmt.Sprintf semantics).
func TextMessagef(role schema.ChatMessageType, format string, a ...any) llms.MessageContent {
	text := fmt.Sprintf(format, a...)
	return TextMessage(role, text)
}
package factory
import (
"encoding/json"
"fmt"
"github.com/pkg/errors"
"github.com/tmc/langchaingo/jsonschema"
"github.com/tmc/langchaingo/llms"
"github.com/tmc/langchaingo/schema"
"reflect"
"runtime"
"sort"
"strings"
"sync"
)
type (
// FunctionFactoryOption configures a Function; NewFunction applies them in order.
FunctionFactoryOption func(function *Function)
// HandlerOption configures handler-related settings of a Function (e.g.
// receivers via WithReceivers); WithFunctionHandler applies them after the
// handler has been set.
HandlerOption func(function *Function)
)
// Function is a function definition and its handler, contains the function's arguments and receivers.
// Definition is what the LLM sees; args, Handler and receivers are used to
// turn the LLM's function call into a Go call and capture its results.
type Function struct {
Definition llms.FunctionDefinition // The function definition for the LLM
args Parameters // The list of arguments for the function, laid out by Order when building the call
Handler any // The function handler, used to execute the function (must be a func value)
receivers []any // The function handler can return any number of values, the receivers will be used to store the results (one pointer per result)
// autoExec: if true, the function will be executed automatically by the AI.
// NewFunction defaults it to true; NewFunctionFromDef does not set it.
// NOTE(review): no code in this file reads autoExec — confirm it is used elsewhere.
autoExec bool
}
// Functions is a collection of functions exposed to the LLM.
// The methods of this type guard handlers and Definitions with m. Definitions
// is exported, however, and callers read it directly (e.g. to pass to
// llms.WithFunctions) without taking the lock.
type Functions struct {
m sync.RWMutex
handlers map[string]*Function // map of function name to function, used to quickly find a function by name and execute it
Definitions []llms.FunctionDefinition // The list of function definitions for the LLM
}
// Exec executes the function's handler with the arguments carried by fnCall
// and stores the handler's results in the receivers.
//
// Parameters
//   - fnCall: the function call sent by the AI
//
// The decoded arguments are checked against the handler's signature before
// the call. The arguments come from model output, so a malformed call (wrong
// arity or argument types) returns an error instead of making
// reflect.Value.Call panic. If the handler returns an error, it is returned
// by Exec.
func (f *Function) Exec(fnCall *schema.FunctionCall) error {
	if fnCall == nil {
		return fmt.Errorf("function call cannot be nil")
	}

	fnVal := reflect.ValueOf(f.Handler)
	if fnVal.Kind() != reflect.Func {
		return fmt.Errorf("expected a function, got %s", fnVal.Kind())
	}

	args, err := f.buildCallArgs(fnCall)
	if err != nil {
		return err
	}
	if err := checkCallArgs(fnVal.Type(), args); err != nil {
		return errors.Wrapf(err, "invalid arguments for function '%s'", f.Definition.Name)
	}

	return f.setValuesToReceivers(fnVal.Call(args))
}

// checkCallArgs verifies that args can be passed to a function of type fnType
// via reflect.Value.Call. It mirrors the checks Call enforces by panicking:
// the argument count (a variadic tail may be empty or repeated) and the
// assignability of each argument to its parameter type.
func checkCallArgs(fnType reflect.Type, args []reflect.Value) error {
	numIn := fnType.NumIn()
	fixed := numIn
	if fnType.IsVariadic() {
		fixed--
		if len(args) < fixed {
			return fmt.Errorf("expected at least %d arguments, got %d", fixed, len(args))
		}
	} else if len(args) != numIn {
		return fmt.Errorf("expected %d arguments, got %d", numIn, len(args))
	}

	for i, arg := range args {
		var want reflect.Type
		if i < fixed {
			want = fnType.In(i)
		} else {
			// Extra arguments feed the variadic slice's element type.
			want = fnType.In(numIn - 1).Elem()
		}
		if !arg.IsValid() {
			return fmt.Errorf("argument %d is missing (expected %s)", i, want)
		}
		if !arg.Type().AssignableTo(want) {
			return fmt.Errorf("argument %d: cannot use %s as %s", i, arg.Type(), want)
		}
	}
	return nil
}
// buildCallArgs converts the JSON arguments of fnCall into the positional
// values passed to the handler.
//
// Arguments are laid out by each parameter's Order; ties keep declaration
// order. How each parameter is handled:
//   - Object parameters are decoded into a fresh value of their Instance's type.
//   - Other parameters that are missing fall back to DefaultValue, then to the
//     zero value of their JSON type.
//   - A variadic parameter (expected to be last) is expanded into one value
//     per element.
//
// Values are converted to GoType when one is configured.
func (f *Function) buildCallArgs(fnCall *schema.FunctionCall) ([]reflect.Value, error) {
	params := map[string]any{}
	// Tolerate an empty argument string for functions called without arguments.
	if strings.TrimSpace(fnCall.Arguments) != "" {
		if err := json.Unmarshal([]byte(fnCall.Arguments), &params); err != nil {
			return nil, err
		}
	}

	if len(f.args) == 0 {
		return nil, nil
	}

	// Sort a copy. Sorting f.args in place would race with concurrent calls.
	// sort.Stable keeps equal-Order parameters (Order defaults to 0) in
	// declaration order.
	ordered := append(Parameters(nil), f.args...)
	sort.Stable(ordered)

	arguments := make([]reflect.Value, 0, len(ordered))
	for _, arg := range ordered {
		if arg.Type == jsonschema.Object {
			value, err := decodeObjectArg(arg.Instance, params[arg.Name])
			if err != nil {
				return nil, err
			}
			arguments = append(arguments, value)
			continue
		}

		value, exists := params[arg.Name]
		if !exists {
			value = arg.DefaultValue
			if value == nil {
				value = defaultParamValue(arg.Type)
			}
		}

		if arg.IsVariadic {
			for _, elem := range expandVariadicArg(reflect.ValueOf(value)) {
				arguments = append(arguments, convertArgValue(elem, arg.GoType))
			}
			continue
		}
		arguments = append(arguments, convertArgValue(reflect.ValueOf(value), arg.GoType))
	}
	return arguments, nil
}

// decodeObjectArg decodes raw (a value taken from the decoded arguments map)
// into a value shaped like instance.
//
// When instance is a pointer, a new pointee is allocated for every call, so
// calls never share state or overwrite the template stored on the Parameter.
// Otherwise the JSON is decoded into a local copy of instance, as the original
// code did with the Parameter's own field.
func decodeObjectArg(instance, raw any) (reflect.Value, error) {
	body, err := json.Marshal(raw)
	if err != nil {
		return reflect.Value{}, errors.Wrap(err, "failed to marshal object body")
	}

	if t := reflect.TypeOf(instance); t != nil && t.Kind() == reflect.Pointer {
		target := reflect.New(t.Elem())
		if err := json.Unmarshal(body, target.Interface()); err != nil {
			return reflect.Value{}, err
		}
		return target, nil
	}

	target := instance
	if err := json.Unmarshal(body, &target); err != nil {
		return reflect.Value{}, err
	}
	return reflect.ValueOf(target), nil
}

// expandVariadicArg splits a slice or array value into its elements so they
// can be passed one by one to a variadic handler.
//
// Interface-typed elements (as produced by decoding a JSON array into []any)
// are unwrapped to their dynamic value. A nil value yields no elements. Any
// other non-slice value is passed as a single element.
func expandVariadicArg(v reflect.Value) []reflect.Value {
	if !v.IsValid() {
		return nil
	}
	if v.Kind() != reflect.Slice && v.Kind() != reflect.Array {
		return []reflect.Value{v}
	}
	elems := make([]reflect.Value, 0, v.Len())
	for i := 0; i < v.Len(); i++ {
		elem := v.Index(i)
		if elem.Kind() == reflect.Interface {
			elem = elem.Elem()
		}
		elems = append(elems, elem)
	}
	return elems
}

// convertArgValue converts v to goType when one is configured and the
// conversion is legal, e.g. the float64 that encoding/json produces for every
// number into a uint. A missing (invalid) value becomes goType's zero value.
// In all other cases v is returned unchanged.
func convertArgValue(v reflect.Value, goType reflect.Type) reflect.Value {
	switch {
	case goType == nil:
		return v
	case !v.IsValid():
		return reflect.Zero(goType)
	case v.Type() == goType || !v.Type().ConvertibleTo(goType):
		return v
	default:
		return v.Convert(goType)
	}
}
// setValuesToReceivers stores the handler's results in f.receivers (result i
// goes to receiver i) and returns the first error the handler reported.
//
// Per receiver:
//   - Nil or zero receivers are skipped.
//   - Receivers must be pointers.
//   - The pointee is assigned the result directly, with one exception: when
//     the pointee is a struct or pointer and the result is a non-empty string
//     or []byte, the result is decoded as JSON into the receiver.
//
// Error handling:
//   - Every result is checked for an error, even when no receiver is
//     registered for it, so handler failures are never dropped silently.
//   - Error results are still stored in their receivers before being returned.
//   - Typed-nil errors (per nerror.IsNil) do not count as failures.
func (f *Function) setValuesToReceivers(result []reflect.Value) error {
	if len(f.receivers) > 0 && len(result) != len(f.receivers) {
		return fmt.Errorf("expected %d results, got %d", len(f.receivers), len(result))
	}

	for i, receiver := range f.receivers {
		if err := assignToReceiver(receiver, result[i]); err != nil {
			return err
		}
	}

	for _, value := range result {
		if e, isErr := value.Interface().(error); isErr && !nerror.IsNil(e) {
			return errors.Wrapf(
				e,
				"function '%s' returned an error",
				f.Definition.Name,
			)
		}
	}
	return nil
}

// assignToReceiver stores value in the variable that receiver points to,
// following the rules described on setValuesToReceivers.
func assignToReceiver(receiver any, value reflect.Value) error {
	rv := reflect.ValueOf(receiver)
	if !rv.IsValid() || rv.IsZero() {
		return nil
	}
	if rv.Kind() != reflect.Pointer {
		return fmt.Errorf("receiver must be a pointer, got %s", rv.Type())
	}

	target := rv.Elem()
	if kind := target.Kind(); kind == reflect.Struct || kind == reflect.Pointer {
		// parses the result as json into the receiver's instance when it is textual
		var data []byte
		switch r := value.Interface().(type) {
		case string:
			data = []byte(r)
		case []byte:
			data = r
		default:
			return setReflectValue(target, value)
		}
		if len(data) == 0 {
			return nil
		}
		return json.Unmarshal(data, receiver)
	}
	return setReflectValue(target, value)
}

// setReflectValue assigns value to target. An interface-typed value (e.g. a
// result declared as `error`) is unwrapped when only its dynamic value is
// assignable, and a nil interface clears target. A type mismatch is returned
// as an error instead of panicking.
func setReflectValue(target, value reflect.Value) error {
	if value.Type().AssignableTo(target.Type()) {
		target.Set(value)
		return nil
	}
	if value.Kind() == reflect.Interface {
		if value.IsNil() {
			target.Set(reflect.Zero(target.Type()))
			return nil
		}
		if inner := value.Elem(); inner.Type().AssignableTo(target.Type()) {
			target.Set(inner)
			return nil
		}
	}
	return fmt.Errorf("cannot assign result of type %s to receiver of type %s", value.Type(), target.Type())
}
// NewFunction creates an empty Function with autoExec enabled, then applies
// options to it in order and returns it.
func NewFunction(options ...FunctionFactoryOption) *Function {
	fn := new(Function)
	fn.autoExec = true
	for i := range options {
		options[i](fn)
	}
	return fn
}
// NewFunctionFromDef wraps an existing definition (typically decoded from
// JSON) in a Function and derives its argument list from the definition's
// JSON-schema "properties".
//
// autoExec defaults to true, as it does in NewFunction. The arguments are
// only derived when Parameters holds a decoded JSON object
// (map[string]any). Definitions without parameters get no arguments instead
// of panicking on the type assertion. NOTE(review): a json.RawMessage
// Parameters (as produced by ParametersFactory) is not parsed here.
func NewFunctionFromDef(fn llms.FunctionDefinition) *Function {
	function := &Function{
		Definition: fn,
		autoExec:   true,
	}
	if args, ok := fn.Parameters.(map[string]any); ok {
		buildProperties(&function.args, args)
	}
	return function
}
// buildProperties appends one Parameter to funcArgs for every entry in the
// JSON-schema object args["properties"], and recurses into nested object
// properties.
//
// Each property reads "type" (default "string"), "description", "order" and
// "enum". A property is marked required when its name appears in
// args["required"]. Properties are visited in name order: map iteration is
// random, and a fixed order keeps argument layout deterministic. Malformed
// entries fall back to defaults instead of panicking on type assertions.
func buildProperties(funcArgs *Parameters, args map[string]interface{}) {
	properties, _ := args["properties"].(map[string]interface{})
	required, _ := args["required"].([]any)

	names := make([]string, 0, len(properties))
	for name := range properties {
		names = append(names, name)
	}
	sort.Strings(names)

	for _, name := range names {
		argMap, ok := properties[name].(map[string]interface{})
		if !ok {
			continue
		}

		argType := "string"
		if t, ok := argMap["type"].(string); ok {
			argType = t
		}
		argDesc, _ := argMap["description"].(string)

		// order is per property; it must not carry over from the previous one.
		order := 0
		if o, ok := argMap["order"].(float64); ok {
			order = int(o)
		}

		var enum []string
		if enumList, ok := argMap["enum"].([]any); ok {
			enum = make([]string, len(enumList))
			for i, e := range enumList {
				enum[i] = fmt.Sprint(e)
			}
		}

		parameter := &Parameter{
			Name:        name,
			Description: argDesc,
			Type:        jsonschema.DataType(argType),
			IsRequired:  isArgRequired(name, required),
			Enum:        enum,
			Order:       order,
		}

		if props, ok := argMap["properties"].(map[string]interface{}); ok && len(props) > 0 {
			buildProperties(&parameter.Properties, argMap)
		}

		*funcArgs = append(*funcArgs, parameter)
	}
}
// isArgRequired reports whether argName is listed in required (the decoded
// JSON-schema "required" array). Non-string entries are ignored rather than
// causing a panic.
func isArgRequired(argName string, required []any) bool {
	for _, entry := range required {
		if name, ok := entry.(string); ok && name == argName {
			return true
		}
	}
	return false
}
// WithFunctionName sets the name the LLM uses to call the function.
func WithFunctionName(name string) FunctionFactoryOption {
	return func(fn *Function) { fn.Definition.Name = name }
}

// WithFunctionDescription sets the description sent to the LLM.
func WithFunctionDescription(description string) FunctionFactoryOption {
	return func(fn *Function) { fn.Definition.Description = description }
}
// WithFunctionParameters sets the function's arguments. It builds the JSON
// schema for the LLM (Definition.Parameters) via ParametersFactory and keeps
// the parameters themselves for building handler calls.
//
// A schema error is logged rather than returned, because options cannot
// fail. In that case Definition.Parameters receives the nil result.
func WithFunctionParameters(parameters ...*Parameter) FunctionFactoryOption {
	return func(function *Function) {
		params, err := ParametersFactory(parameters...)
		if err != nil {
			nlogger.Warnf("error creating parameters for function %q. Err: %v", function.Definition.Name, err)
		}
		function.Definition.Parameters = params
		function.args = parameters
	}
}
// WithFunctionAutoExec overrides the autoExec flag (NewFunction defaults it to true).
func WithFunctionAutoExec(autoExec bool) FunctionFactoryOption {
	return func(fn *Function) { fn.autoExec = autoExec }
}
// WithFunctionHandler sets the function handler, then applies opts (e.g.
// WithReceivers) to the function.
//
// Parameters
//   - handler: the Go function invoked when the LLM calls this function
//   - opts: handler options, applied in order after the handler is set
//
// If no name has been set yet, the name is inferred from the handler with
// FunctionNameFromHandler.
func WithFunctionHandler(handler any, opts ...HandlerOption) FunctionFactoryOption {
	return func(function *Function) {
		if len(function.Definition.Name) == 0 {
			function.Definition.Name = FunctionNameFromHandler(handler)
		}
		function.Handler = handler
		for i := range opts {
			opts[i](function)
		}
	}
}
// WithReceivers sets the pointers that receive the handler's results: the
// i-th receiver gets the i-th return value. The slice is stored as given.
func WithReceivers(receivers ...any) HandlerOption {
	return func(fn *Function) { fn.receivers = receivers }
}
// FunctionNameFromHandler returns the unqualified name of the function value
// handler, e.g. "RegisterUser" for data.NewUserData(cnn).RegisterUser.
//
// Method values are compiled into wrappers whose runtime names end in "-fm",
// and generic instantiations end in "[...]"; both suffixes are stripped.
// Closures keep their compiler-generated names (e.g. "func1"). A nil or
// non-function handler yields "" instead of a panic.
func FunctionNameFromHandler(handler any) string {
	fnVal := reflect.ValueOf(handler)
	if fnVal.Kind() != reflect.Func || fnVal.IsNil() {
		return ""
	}

	fullName := runtime.FuncForPC(fnVal.Pointer()).Name()
	fullName = strings.TrimSuffix(fullName, "-fm")
	fullName = strings.TrimSuffix(fullName, "[...]")

	if i := strings.LastIndex(fullName, "."); i >= 0 {
		return fullName[i+1:]
	}
	return fullName
}
// NewFunctions returns an empty, ready-to-use Functions collection.
func NewFunctions() *Functions {
	fns := new(Functions)
	fns.handlers = make(map[string]*Function)
	return fns
}
// Add registers fn in the collection under its definition name.
//
// Registering a name that already exists replaces both the handler and the
// matching entry in Definitions, so the LLM never receives duplicate function
// definitions. A nil fn is ignored, and a zero-value Functions is usable.
func (f *Functions) Add(fn *Function) {
	if fn == nil {
		return
	}

	f.m.Lock()
	defer f.m.Unlock()

	name := fn.Definition.Name
	if f.handlers == nil {
		f.handlers = make(map[string]*Function)
	}
	f.handlers[name] = fn

	for i := range f.Definitions {
		if f.Definitions[i].Name == name {
			f.Definitions[i] = fn.Definition
			return
		}
	}
	f.Definitions = append(f.Definitions, fn.Definition)
}
// Exec runs the registered function named by fnCall and returns its receivers
// marshaled as a JSON array.
//
// It fails when the function is unknown, has no handler, when execution
// fails, or when the receivers cannot be marshaled. The read lock is held for
// the whole execution.
func (f *Functions) Exec(fnCall *schema.FunctionCall) (string, error) {
	f.m.RLock()
	defer f.m.RUnlock()

	name := fnCall.Name
	fn, found := f.handlers[name]
	switch {
	case !found:
		return "", fmt.Errorf("function %s not found", name)
	case fn.Handler == nil:
		return "", fmt.Errorf("function %s has no handler", name)
	}

	if err := fn.Exec(fnCall); err != nil {
		return "", errors.Wrapf(err, "failed to execute function %s", name)
	}

	payload, err := json.Marshal(fn.receivers)
	if err != nil {
		return "", errors.Wrapf(err, "failed to marshal function %s result", name)
	}
	return string(payload), nil
}
// SetHandler replaces the handler (and applies opts, e.g. WithReceivers) of
// the function registered as funcName. Unknown names are ignored.
//
// A write lock is required because WithFunctionHandler mutates the stored
// *Function, which Exec reads under the read lock. Holding only the read lock
// here would be a data race.
func (f *Functions) SetHandler(funcName string, handler any, opts ...HandlerOption) {
	f.m.Lock()
	defer f.m.Unlock()

	fn, ok := f.handlers[funcName]
	if !ok {
		return
	}
	WithFunctionHandler(handler, opts...)(fn)
}
// SetReceivers sets the result receivers of the function registered as
// funcName (receiver i gets the handler's i-th return value). Unknown names
// are ignored.
//
// A write lock is required because this mutates the stored *Function, which
// Exec reads under the read lock.
func (f *Functions) SetReceivers(funcName string, receivers ...any) {
	f.m.Lock()
	defer f.m.Unlock()

	fn, ok := f.handlers[funcName]
	if !ok {
		return
	}
	WithReceivers(receivers...)(fn)
}
// Merge copies every function of fn into f. Entries with the same name are
// replaced, both the handler and its definition, so Definitions never holds
// duplicates. Merging a collection into itself, or merging nil, is a no-op.
//
// fn is snapshotted under its own read lock before f is locked. The two locks
// are therefore never held together, and a.Merge(b) racing with b.Merge(a)
// cannot deadlock.
func (f *Functions) Merge(fn *Functions) {
	if fn == nil || fn == f {
		return
	}

	fn.m.RLock()
	handlers := make(map[string]*Function, len(fn.handlers))
	for name, handler := range fn.handlers {
		handlers[name] = handler
	}
	definitions := append([]llms.FunctionDefinition(nil), fn.Definitions...)
	fn.m.RUnlock()

	f.m.Lock()
	defer f.m.Unlock()

	if f.handlers == nil {
		f.handlers = make(map[string]*Function, len(handlers))
	}
	for name, handler := range handlers {
		f.handlers[name] = handler
	}

	for _, def := range definitions {
		replaced := false
		for i := range f.Definitions {
			if f.Definitions[i].Name == def.Name {
				f.Definitions[i] = def
				replaced = true
				break
			}
		}
		if !replaced {
			f.Definitions = append(f.Definitions, def)
		}
	}
}
package factory
import (
"encoding/json"
"github.com/pkg/errors"
"github.com/tmc/langchaingo/jsonschema"
"github.com/tmc/langchaingo/llms"
"path"
"reflect"
)
// currentPath is the directory reported by file.CallerDir() at package
// initialization (presumably this source file's directory). Its error is
// discarded. NOTE(review): confirm what CallerDir returns on failure.
var currentPath, _ = file.CallerDir()

// defaultInputPath is the functions.json file inside currentPath.
var defaultInputPath = path.Join(currentPath, "functions.json")
// LoadFromDefaultJSON loads the functions from defaultInputPath. That is
// functions.json in the directory returned by file.CallerDir() when the
// package is initialized, presumably src/data/agent/factory.
func LoadFromDefaultJSON() (*Functions, error) {
	loaded, err := LoadFromJSON(defaultInputPath)
	return loaded, err
}
// LoadFromJSON reads a JSON array of function definitions from jsonPath and
// returns them as a Functions collection, with no handlers attached. It fails
// if the file does not exist or cannot be decoded.
func LoadFromJSON(jsonPath string) (*Functions, error) {
	if exists := file.ExistsFile(jsonPath); !exists {
		return nil, errors.Errorf("file %s does not exist", jsonPath)
	}

	definitions, err := readJson(jsonPath)
	if err != nil {
		return nil, err
	}

	loaded := NewFunctions()
	for i := range definitions {
		loaded.Add(NewFunctionFromDef(definitions[i]))
	}
	return loaded, nil
}
// readJson decodes the JSON array of function definitions stored at jsonPath.
// An empty jsonPath means defaultInputPath. Read and decode errors are
// returned unchanged.
func readJson(jsonPath string) (functions []llms.FunctionDefinition, _ error) {
	if len(jsonPath) == 0 {
		jsonPath = defaultInputPath
	}

	raw, err := file.ReadFile(jsonPath)
	if err != nil {
		return nil, err
	}
	if err := json.Unmarshal(raw, &functions); err != nil {
		return nil, err
	}
	return functions, nil
}
// LoadFunctions builds the function registry for the agent. It optionally
// merges the definitions from the default JSON file first, then adds the
// user-management functions bound to cnn.
func LoadFunctions(cnn *db.ConnectionHandler, includeJSON bool) (*Functions, error) {
	functions := NewFunctions()

	if includeJSON {
		fromJSON, err := LoadFromDefaultJSON()
		if err != nil {
			return nil, err
		}
		functions.Merge(fromJSON)
	}

	builders := []func(*db.ConnectionHandler) *Function{
		createFunctionRegisterUser,
		createFunctionReadUser,
		createFunctionReadUsers,
		createFunctionUpdateUser,
		createFunctionDeleteUser,
	}
	for _, build := range builders {
		functions.Add(build(cnn))
	}
	return functions, nil
}
// createFunctionRegisterUser exposes UserData.RegisterUser to the LLM. The
// function takes a user object and a sendEmail flag.
func createFunctionRegisterUser(cnn *db.ConnectionHandler) *Function {
	name := WithFunctionName("RegisterUser")
	params := WithFunctionParameters(
		paramModelUser("user"),
		NewParameter(
			"sendEmail",
			WithParameterDescription(
				"flag to indicate if the user should receive an email to confirm his account",
			),
			WithParameterType(jsonschema.Boolean),
		),
	)
	handler := WithFunctionHandler(data.NewUserData(cnn).RegisterUser)
	description := WithFunctionDescription(
		"creates a new user in the DB and in cognito, it returns the user's instance. If the user already exists, it returns an error",
	)
	return NewFunction(name, params, handler, description)
}
// createFunctionReadUser exposes UserData.ReadUser to the LLM. It looks a user
// up by database ID, can include soft-deleted users, and can preload the
// requested associations.
//
// The description sent to the model now says the lookup is by ID, matching
// the parameters; it previously claimed the lookup was by "sub".
func createFunctionReadUser(cnn *db.ConnectionHandler) *Function {
	return NewFunction(
		WithFunctionName("ReadUser"),
		WithFunctionParameters(
			// mirrors the handler signature: id uint, includeSoftDeleted bool, associations ...string
			NewParameter(
				"id",
				WithParameterType(jsonschema.Integer),
				// JSON numbers decode as float64; convert to the handler's uint.
				WithParameterGoType(reflect.TypeOf(uint(0))),
				WithParameterDescription(
					"The user's ID in the database, it's auto-generated. Required for read, update or delete user.",
				),
				WithParameterRequired(),
			),
			NewParameter(
				"includeSoftDeleted",
				WithParameterType(jsonschema.Boolean),
				WithParameterDescription(
					"Flag to indicate if the soft deleted users should be included in the result",
				),
			),
			associations("associations", preloads.User()...),
		),
		WithFunctionHandler(
			data.NewUserData(cnn).ReadUser,
		),
		WithFunctionDescription(
			"reads a user from the DB by the user's ID",
		),
	)
}
// createFunctionReadUsers exposes UserData.ReadUsers to the LLM. It takes a
// paginated request (page, size, sort, filters) and the associations to
// preload.
//
// Every other user function carries a description; this one had none, which
// leaves the model guessing when to call it.
func createFunctionReadUsers(cnn *db.ConnectionHandler) *Function {
	return NewFunction(
		WithFunctionName("ReadUsers"),
		WithFunctionParameters(
			paginationRequest("req"),
			associations("associations", preloads.User()...),
		),
		WithFunctionHandler(data.NewUserData(cnn).ReadUsers),
		WithFunctionDescription(
			"reads a page of users from the DB, applying the requested pagination, sorting and filters",
		),
	)
}
// createFunctionUpdateUser exposes UserData.UpdateUser to the LLM. The
// function takes the full user object.
func createFunctionUpdateUser(cnn *db.ConnectionHandler) *Function {
	name := WithFunctionName("UpdateUser")
	params := WithFunctionParameters(paramModelUser("user"))
	handler := WithFunctionHandler(data.NewUserData(cnn).UpdateUser)
	description := WithFunctionDescription(
		"updates a user in the DB and in cognito, it returns the user's instance.",
	)
	return NewFunction(name, params, handler, description)
}
// createFunctionDeleteUser exposes UserData.DeleteUser to the LLM. The
// function takes the full user object.
func createFunctionDeleteUser(cnn *db.ConnectionHandler) *Function {
	name := WithFunctionName("DeleteUser")
	params := WithFunctionParameters(paramModelUser("user"))
	handler := WithFunctionHandler(data.NewUserData(cnn).DeleteUser)
	description := WithFunctionDescription(
		"deletes a user from the DB and from cognito, it returns the user's instance.",
	)
	return NewFunction(name, params, handler, description)
}
// paramModelUser returns a parameter model for the struct models.User.
// It is a required object parameter named name. Its schema lists the user
// fields the LLM may send (id, name, email, sub, sso, role). When the function
// is called, the JSON object is decoded into a value of the Instance's type
// (*models.User).
func paramModelUser(name string) *Parameter {
return NewParameter(
name,
WithParameterType(jsonschema.Object),
// template instance: buildCallArgs decodes the object argument into this type
WithParameterInstance(&models.User{}),
WithParameterRequired(),
WithParameterDescription(
"The user model to manage a user instance in the system. It contains basic information and the permissions to access the system",
),
WithParameters(
NewParameter(
"id",
WithParameterType(jsonschema.Integer),
WithParameterDescription(
"The user's ID in the database, it's auto-generated. Required for read, update or delete user.",
),
),
NewParameter("name", WithParameterRequired(), WithParameterDescription("The user's name")),
NewParameter(
"email",
WithParameterRequired(),
WithParameterDescription(
"The user's email. It's unique, if the email already exists, it will return an error",
),
),
NewParameter(
"sub",
WithParameterDescription(
"The user's sub, this is set by cognito when the user is created in the cognito pool",
),
),
NewParameter(
"sso",
WithParameterType(jsonschema.Boolean),
WithParameterDescription(
"A flag to indicates if the user is using SSO, this is set by business rules",
),
),
// role is an array of strings restricted to the values of UserRole.Any.GetAsArray()
NewParameter(
"role",
WithParameterType(jsonschema.Array),
WithParameterDescription(
"The user's roles, it's an array of string where it can contain at least one role from the enum list, for example: [\"AgentUser\", \"OperationsUser\"]",
),
WithParameterRequired(),
WithItems(jsonschema.String,
WithParameterDescription("Each element is one role for the user"),
WithParameterEnum(UserRole.Any.GetAsArray()...),
),
),
),
)
}
// associations returns an optional, variadic array parameter called name.
// Its string items are restricted to values, and it lists the DB relations to
// preload. When the model omits it, it defaults to an empty list, which is
// then passed to the handler as no variadic arguments.
//
// The name argument was previously ignored in favour of the hard-coded
// "associations". Every current caller passes "associations", so their
// schemas are unchanged.
func associations(name string, values ...string) *Parameter {
	return NewParameter(
		name,
		WithParameterType(jsonschema.Array),
		WithParameterIsVariadic(true),
		WithParameterDescription(
			"Associations (relations in the db) to include in the result",
		),
		WithItems(jsonschema.String, WithParameterEnum(values...)),
		WithParameterDefaultValue([]string{}),
	)
}
// paginationRequest returns an object parameter called name. The object is
// decoded into ntypes.PaginatedRequest and describes paging (page_number,
// page_size), GORM-style sorting, and a list of filter rules for list queries.
func paginationRequest(name string) *Parameter {
	return NewParameter(
		name,
		WithParameterType(jsonschema.Object),
		WithParameterInstance(&ntypes.PaginatedRequest{}),
		WithParameterDescription(
			"The paginated request to read the users. Here specify the page size, current page, sort and the filters",
		),
		WithParameters(
			NewParameter(
				"page_number",
				WithParameterDescription("The current page number"),
				WithParameterType(jsonschema.Integer),
				WithParameterRequired(),
			),
			NewParameter(
				"page_size",
				WithParameterType(jsonschema.Integer),
				WithParameterDescription("The page size"),
				WithParameterRequired(),
			),
			NewParameter(
				"sort",
				WithParameterDescription(
					"The sorting rules for the query. Uses the same format of the GORM library (https://gorm.io/docs/query.html#Order). Example: age desc, name",
				),
			),
			NewParameter(
				"filters",
				WithParameterType(jsonschema.Array),
				WithParameterDescription("The list of rules to filter the result"),
				WithItems(
					jsonschema.Object,
					WithParameterDescription("Each element is a filter rule"),
					WithParameters(
						NewParameter(
							"prefix",
							WithParameterDescription("prefix of the property, used when the query contains joins"),
						),
						NewParameter("property", WithParameterDescription("property name")),
						NewParameter(
							"value",
							WithParameterDescription("value to compare, same type of the property type"),
						),
						NewParameter(
							"logic_operator",
							WithParameterDescription("logic operator"),
							WithParameterEnum("AND", "OR"),
						),
						NewParameter(
							"rel_operator",
							WithParameterDescription("relational operator"),
							// Each operator appears once: JSON Schema enum values should be unique.
							WithParameterEnum(
								"=",
								">",
								">=",
								"<",
								"<=",
								"<>",
								"LIKE",
								"NOT LIKE",
								"IN",
							),
						),
					),
				),
			),
		),
	)
}
package main
import (
"context"
"fmt"
"github.com/sirupsen/logrus"
"github.com/tmc/langchaingo/llms"
)
// OPENAI_API_KEY is the placeholder API key passed to agent.NewChat by main.
// NOTE(review): load the real key from the environment or a secret store
// instead of committing it to source.
const OPENAI_API_KEY = "<api_key>"
// main runs one prompt through the agent:
//   - connects to the DB and registers the user functions;
//   - binds their results to local receivers;
//   - sends a single prompt to the model and prints its answer;
//   - writes any captured user(s) to user.json / users.json.
func main() {
	logrus.SetLevel(logrus.DebugLevel)

	cnn, err := db.NewConnection()
	if err != nil {
		panic(err)
	}

	functions, err := factory.LoadFunctions(cnn, false)
	if err != nil {
		panic(err)
	}

	// Receivers filled by the handlers when the model calls a function.
	var (
		user    *models.User
		users   []*models.User
		userErr *nerror.Error
	)

	// Alternative prompts used while experimenting:
	//msg = `create the user jonh doe with the email [email protected] as SuperAdmin. Don't send email confirmation. If there is an error, explain what happened.`
	//msg = `who is the user with id 1?. If there is an error, explain what happened.`
	//msg = `how many users are there?. If there is an error, explain what happened.`
	/*msg = `list the top 10 the users registered in the system sorted by name asc, the format output will be a list of users with the following fields: <sub>: <name> - <email>.
	Example: 12063ab6-4c76-465c-a02f-090b4c92ad6e: Michael Scott - [email protected]
	If there is an error, explain what happened.`*/
	/*msg = `Update the roles to AgentAdmin and OperationsAdmin for the user with id 1,
	sub a91c2b21-a4de-4c08-bcd0-35e54a40f6fe, email [email protected] and name Michael Scott.
	If there is an error, explain what happened.`*/
	msg := `delete the user with id 1. If there is an error, explain what happened.`

	bindings := []struct {
		function string
		result   any
	}{
		{function: "RegisterUser", result: &user},
		{function: "ReadUser", result: &user},
		{function: "ReadUsers", result: &users},
		{function: "UpdateUser", result: &user},
		{function: "DeleteUser", result: &user},
	}
	for _, b := range bindings {
		functions.SetReceivers(b.function, b.result, &userErr)
	}

	chat, err := agent.NewChat("gpt-4-1106-preview", OPENAI_API_KEY, functions)
	if err != nil {
		nlogger.Fatal(err)
	}

	prompt := []llms.MessageContent{agent.HumaneMessage(msg)}
	completion, err := chat.Call(context.Background(), prompt, llms.WithFunctions(functions.Definitions))
	if err != nil {
		nlogger.Fatal(err)
	}

	fmt.Println(">>>> Agent Response <<<<")
	for _, choice := range completion.Choices {
		fmt.Println(choice.Content)
	}
	fmt.Println(">>>> End Agent Response <<<<")

	printErr("User", userErr)
	if user != nil {
		_ = file.WriteJSONFile(user, "user.json")
	}
	if len(users) > 0 {
		_ = file.WriteJSONFile(users, "users.json")
	}
}
// printErr prints "<name>: <err>" unless err is nil. Typed-nil errors, as
// judged by nerror.IsNil, also count as nil.
func printErr(name string, err error) {
	if !nerror.IsNil(err) {
		fmt.Printf("%s: %s\n", name, err)
	}
}
package factory
import (
"encoding/json"
"fmt"
"github.com/tmc/langchaingo/jsonschema"
"reflect"
)
// ParameterFactoryOption configures a Parameter; NewParameter applies them in order.
type ParameterFactoryOption func(parameter *Parameter)
// Parameter describes one function argument in two ways:
//   - as JSON Schema sent to the LLM (the fields with json tags);
//   - as metadata used to build the Go handler call (the fields tagged json:"-").
type Parameter struct {
// Name is the key of this parameter in the enclosing "properties" object.
Name string `json:"-"`
IsRequired bool `json:"-"` // Default is false
Instance any `json:"-"` // Instance of the parameter to set value, used if is an object
Type jsonschema.DataType `json:"type"` // Default is string
// GoType, when set, is the type decoded values are converted to before the handler call.
GoType reflect.Type `json:"-"`
DefaultValue any `json:"-"` // Default is nil
IsVariadic bool `json:"-"` // Default is false
Description string `json:"description,omitempty"`
Enum []string `json:"enum,omitempty"`
// Order is the position of the argument in the handler call.
// NOTE(review): it is also emitted into the schema as a non-standard "order" keyword.
Order int `json:"order,omitempty"`
Properties Parameters `json:"properties,omitempty"` // sub-properties when Type is object
Items *Parameter `json:"items,omitempty"` // element schema when Type is array
Required []string `json:"required,omitempty"` // names of required Properties; filled by ParametersFactory
}
// NewParameter creates a parameter with the given name and type string, then
// applies options to it in order.
func NewParameter(name string, options ...ParameterFactoryOption) *Parameter {
	p := &Parameter{Name: name, Type: jsonschema.String}
	for _, apply := range options {
		apply(p)
	}
	return p
}
// WithParameterDescription sets the description sent to the LLM in the
// parameter's schema.
func WithParameterDescription(description string) ParameterFactoryOption {
	return func(p *Parameter) { p.Description = description }
}

// WithParameterEnum restricts the parameter to the given values (schema
// "enum"). The variadic slice is stored as given.
func WithParameterEnum(enum ...string) ParameterFactoryOption {
	return func(p *Parameter) { p.Enum = enum }
}

// WithParameterType sets the parameter's JSON schema type.
func WithParameterType(dataType jsonschema.DataType) ParameterFactoryOption {
	return func(p *Parameter) { p.Type = dataType }
}
// WithParameterTypeString sets the parameter's type from a Go type or kind
// name (e.g. "int64", "struct"), mapped through jsonType.
func WithParameterTypeString(dataType string) ParameterFactoryOption {
	converted := jsonType(dataType)
	return WithParameterType(converted)
}
// WithParameterRequired marks the parameter as required in its parent's schema.
func WithParameterRequired() ParameterFactoryOption {
	return func(p *Parameter) { p.IsRequired = true }
}

// WithParameterInstance sets the template value that an object argument is
// decoded into when the function is called.
func WithParameterInstance(instance any) ParameterFactoryOption {
	return func(p *Parameter) { p.Instance = instance }
}
// WithParameters sets the sub-properties of an object parameter, replacing
// any previously set ones.
func WithParameters(parameters ...*Parameter) ParameterFactoryOption {
	return func(p *Parameter) { p.Properties = parameters }
}
// WithItems sets the element schema of an array parameter. The element is a
// new parameter named "items" with options applied. Its type is always t,
// even if options include WithParameterType.
//
// The closure no longer appends to its captured options slice. Previously,
// every application grew that slice and could write into the caller's
// backing array when options was passed with spare capacity.
func WithItems(t jsonschema.DataType, options ...ParameterFactoryOption) ParameterFactoryOption {
	return func(parameter *Parameter) {
		items := NewParameter("items", options...)
		items.Type = t
		parameter.Items = items
	}
}
// WithParameterDefaultValue sets the value used when the agent does not send
// this (non-object) parameter.
func WithParameterDefaultValue(value any) ParameterFactoryOption {
	return func(p *Parameter) { p.DefaultValue = value }
}

// WithParameterGoType sets the Go type decoded values are converted to before
// the handler is called.
func WithParameterGoType(t reflect.Type) ParameterFactoryOption {
	return func(p *Parameter) { p.GoType = t }
}

// WithParameterIsVariadic marks the parameter as the handler's variadic tail.
func WithParameterIsVariadic(isVariadic bool) ParameterFactoryOption {
	return func(p *Parameter) { p.IsVariadic = isVariadic }
}
// ParametersFactory builds the JSON schema object ({"type": "object", ...})
// used as FunctionDefinition.Parameters from the given top-level parameters.
//
// It returns an error when a parameter is nil, has an empty name, or shares
// its name with another parameter.
//
// As a side effect it records, on every object parameter, the names of its
// required properties in Required. This applies at any depth, including array
// item schemas. Names already present are not added again, so calling it
// repeatedly with the same parameters does not duplicate entries.
func ParametersFactory(parameters ...*Parameter) (json.RawMessage, error) {
	var (
		required     []string
		params       = make(map[string]*Parameter, len(parameters))
		paramsStruct = map[string]any{
			"type": "object",
		}
	)
	for _, param := range parameters {
		if param == nil {
			return nil, fmt.Errorf("param cannot be nil")
		}
		if param.Name == "" {
			return nil, fmt.Errorf("param name cannot be empty")
		}
		if _, exists := params[param.Name]; exists {
			return nil, fmt.Errorf("param %s was sent twice", param.Name)
		}
		if param.IsRequired {
			required = append(required, param.Name)
		}
		collectRequiredProperties(param)
		params[param.Name] = param
	}
	if len(params) > 0 {
		paramsStruct["properties"] = params
	}
	if len(required) > 0 {
		paramsStruct["required"] = required
	}
	return json.Marshal(paramsStruct)
}

// collectRequiredProperties adds to p.Required the name of each required
// property of p that is not listed yet. It then recurses into every property
// and into p.Items, so nested object schemas get their own "required" list.
func collectRequiredProperties(p *Parameter) {
	if p == nil {
		return
	}
	for _, sub := range p.Properties {
		if sub == nil {
			continue
		}
		if sub.IsRequired && !containsName(p.Required, sub.Name) {
			p.Required = append(p.Required, sub.Name)
		}
		collectRequiredProperties(sub)
	}
	collectRequiredProperties(p.Items)
}

// containsName reports whether name is present in names.
func containsName(names []string, name string) bool {
	for _, n := range names {
		if n == name {
			return true
		}
	}
	return false
}
// WithInstance sets the template value an object argument is decoded into.
// It behaves exactly like WithParameterInstance.
func WithInstance(instance any) ParameterFactoryOption {
	return func(p *Parameter) { p.Instance = instance }
}
// defaultParamValue returns the zero value used for a missing scalar argument
// of the given JSON type:
//   - "" for string;
//   - 0 (an int) for number and integer;
//   - false for boolean;
//   - nil for arrays, objects and unknown types.
func defaultParamValue(dataType jsonschema.DataType) any {
	switch dataType {
	case jsonschema.String:
		return ""
	case jsonschema.Number, jsonschema.Integer:
		return 0
	case jsonschema.Boolean:
		return false
	}
	return nil
}
// goTypeNameToJSONType maps Go type and kind names to JSON schema types for jsonType.
var goTypeNameToJSONType = map[string]jsonschema.DataType{
	"string":  jsonschema.String,
	"int":     jsonschema.Number,
	"int8":    jsonschema.Number,
	"int16":   jsonschema.Number,
	"int32":   jsonschema.Number,
	"int64":   jsonschema.Number,
	"float32": jsonschema.Number,
	"float64": jsonschema.Number,
	"uint":    jsonschema.Number,
	"uint8":   jsonschema.Number,
	"uint16":  jsonschema.Number,
	"uint32":  jsonschema.Number,
	"uint64":  jsonschema.Number,
	"byte":    jsonschema.Number,
	"rune":    jsonschema.Number,
	"bool":    jsonschema.Boolean,
	"slice":   jsonschema.Array,
	"array":   jsonschema.Array,
	"map":     jsonschema.Object,
	"struct":  jsonschema.Object,
}

// jsonType converts a Go type or kind name (e.g. "int64", "struct") to its
// JSON schema type. Every numeric kind maps to Number, and unknown names
// default to String.
func jsonType(dataType string) jsonschema.DataType {
	if t, ok := goTypeNameToJSONType[dataType]; ok {
		return t
	}
	return jsonschema.String
}
// Parameters is a list of parameters. It implements sort.Interface, ordering
// by ascending Order.
type Parameters []*Parameter

// Len returns the number of parameters.
func (p Parameters) Len() int { return len(p) }

// Less reports whether parameter i comes before parameter j by Order.
func (p Parameters) Less(i, j int) bool { return p[i].Order < p[j].Order }

// Swap exchanges the parameters at positions i and j.
func (p Parameters) Swap(i, j int) { p[j], p[i] = p[i], p[j] }
// Get returns the first parameter called name, or nil if there is none.
func (p Parameters) Get(name string) *Parameter {
	for i := range p {
		if p[i].Name == name {
			return p[i]
		}
	}
	return nil
}
// SetInstance sets the Instance of the parameter called name. Unknown names
// are ignored.
func (p Parameters) SetInstance(name string, instance any) {
	if param := p.Get(name); param != nil {
		param.Instance = instance
	}
}
// MarshalJSON encodes the list as a JSON object keyed by parameter name, which
// is the shape of a JSON-schema "properties" object. If two parameters share a
// name, the later one wins.
func (p Parameters) MarshalJSON() ([]byte, error) {
	byName := make(map[string]*Parameter, len(p))
	for i := range p {
		byName[p[i].Name] = p[i]
	}
	return json.Marshal(byName)
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment