From 821d78a7193713bacc4d61b9eb5a2f9bb9f586ba Mon Sep 17 00:00:00 2001 From: Dave Lee Date: Wed, 24 Jul 2024 22:17:08 -0400 Subject: [PATCH] propagate errors around Signed-off-by: Dave Lee --- core/http/endpoints/openai/chat.go | 10 ++- pkg/functions/function_structure.go | 7 ++- pkg/functions/functions.go | 9 ++- pkg/functions/grammar_json_schema.go | 76 ++++++++++++++++------- pkg/functions/grammar_json_schema_test.go | 41 +++++++----- 5 files changed, 99 insertions(+), 44 deletions(-) diff --git a/core/http/endpoints/openai/chat.go b/core/http/endpoints/openai/chat.go index f63a991319d..c7afb7bf95e 100644 --- a/core/http/endpoints/openai/chat.go +++ b/core/http/endpoints/openai/chat.go @@ -226,9 +226,15 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, startup // Update input grammar jsStruct := funcs.ToJSONStructure(config.FunctionsConfig.FunctionNameKey, config.FunctionsConfig.FunctionNameKey) - config.Grammar = jsStruct.Grammar(config.FunctionsConfig.GrammarConfig.Options()...) + g, err := jsStruct.Grammar(config.FunctionsConfig.GrammarConfig.Options()...) + if err == nil { + config.Grammar = g + } case input.JSONFunctionGrammarObject != nil: - config.Grammar = input.JSONFunctionGrammarObject.Grammar(config.FunctionsConfig.GrammarConfig.Options()...) + g, err := input.JSONFunctionGrammarObject.Grammar(config.FunctionsConfig.GrammarConfig.Options()...) + if err == nil { + config.Grammar = g + } default: // Force picking one of the functions by the request if config.FunctionToCall() != "" { diff --git a/pkg/functions/function_structure.go b/pkg/functions/function_structure.go index 650236ec0a8..62cc68fa0c1 100644 --- a/pkg/functions/function_structure.go +++ b/pkg/functions/function_structure.go @@ -13,10 +13,13 @@ type JSONFunctionStructure struct { Defs map[string]interface{} `json:"$defs,omitempty"` } -func (j JSONFunctionStructure) Grammar(options ...func(*GrammarOption)) string { +func (j JSONFunctionStructure) Grammar(options ...func(*GrammarOption)) (string, error) { grammarOpts := &GrammarOption{} grammarOpts.Apply(options...) - dat, _ := json.Marshal(j) + dat, err := json.Marshal(j) + if err != nil { + return "", err + } return NewJSONSchemaConverter(grammarOpts.PropOrder).GrammarFromBytes(dat, options...) } diff --git a/pkg/functions/functions.go b/pkg/functions/functions.go index 4f97f40950d..2690b8ec4cd 100644 --- a/pkg/functions/functions.go +++ b/pkg/functions/functions.go @@ -96,7 +96,10 @@ func (f Functions) Select(name string) Functions { return funcs } -func jsonString(v interface{}) string { - b, _ := json.Marshal(v) - return string(b) +func jsonString(v interface{}) (string, error) { + b, err := json.Marshal(v) + if err != nil { + return "", err + } + return string(b), nil } diff --git a/pkg/functions/grammar_json_schema.go b/pkg/functions/grammar_json_schema.go index 4c958ee7c59..5ffc0ba5e85 100644 --- a/pkg/functions/grammar_json_schema.go +++ b/pkg/functions/grammar_json_schema.go @@ -32,11 +32,15 @@ func NewJSONSchemaConverter(propOrder string) *JSONSchemaConverter { } } -func (sc *JSONSchemaConverter) formatLiteral(literal interface{}) string { - escaped := GRAMMAR_LITERAL_ESCAPE_RE.ReplaceAllStringFunc(jsonString(literal), func(match string) string { +func (sc *JSONSchemaConverter) formatLiteral(literal interface{}) (string, error) { + jLiteral, err := jsonString(literal) + if err != nil { + return "", err + } + escaped := GRAMMAR_LITERAL_ESCAPE_RE.ReplaceAllStringFunc(jLiteral, func(match string) string { return GRAMMAR_LITERAL_ESCAPES[match] }) - return fmt.Sprintf(`"%s"`, escaped) + return fmt.Sprintf(`"%s"`, escaped), nil } func (sc *JSONSchemaConverter) addRule(name, rule string) string { @@ -140,7 +144,7 @@ func (sc *JSONSchemaConverter) finalizeGrammar(options ...func(*GrammarOption)) return strings.Join(lines, "\n") } -func (sc *JSONSchemaConverter) visit(schema map[string]interface{}, name string, rootSchema map[string]interface{}) string { +func (sc *JSONSchemaConverter) visit(schema map[string]interface{}, name string, rootSchema map[string]interface{}) (string, error) { st, existType := schema["type"] var schemaType string if existType { @@ -159,31 +163,44 @@ func (sc *JSONSchemaConverter) visit(schema map[string]interface{}, name string, if oneOfExists { for i, altSchema := range oneOfSchemas { - alternative := sc.visit(altSchema.(map[string]interface{}), fmt.Sprintf("%s-%d", ruleName, i), rootSchema) + alternative, err := sc.visit(altSchema.(map[string]interface{}), fmt.Sprintf("%s-%d", ruleName, i), rootSchema) + if err != nil { + return "", err + } alternatives = append(alternatives, alternative) } } else if anyOfExists { for i, altSchema := range anyOfSchemas { - alternative := sc.visit(altSchema.(map[string]interface{}), fmt.Sprintf("%s-%d", ruleName, i), rootSchema) + alternative, err := sc.visit(altSchema.(map[string]interface{}), fmt.Sprintf("%s-%d", ruleName, i), rootSchema) + if err != nil { + return "", err + } alternatives = append(alternatives, alternative) } } rule := strings.Join(alternatives, " | ") - return sc.addRule(ruleName, rule) + return sc.addRule(ruleName, rule), nil } else if ref, exists := schema["$ref"].(string); exists { referencedSchema := sc.resolveReference(ref, rootSchema) return sc.visit(referencedSchema, name, rootSchema) } else if constVal, exists := schema["const"]; exists { - return sc.addRule(ruleName, sc.formatLiteral(constVal)) + literal, err := sc.formatLiteral((constVal)) + if err != nil { + return "", err + } + return sc.addRule(ruleName, literal), nil } else if enumVals, exists := schema["enum"].([]interface{}); exists { var enumRules []string for _, enumVal := range enumVals { - enumRule := sc.formatLiteral(enumVal) + enumRule, err := sc.formatLiteral(enumVal) + if err != nil { + return "", err + } enumRules = append(enumRules, enumRule) } rule := strings.Join(enumRules, " | ") - return sc.addRule(ruleName, rule) + return sc.addRule(ruleName, rule), nil } else if properties, exists := schema["properties"].(map[string]interface{}); schemaType == "object" && exists { propOrder := sc.propOrder var propPairs []struct { @@ -213,21 +230,30 @@ func (sc *JSONSchemaConverter) visit(schema map[string]interface{}, name string, for i, propPair := range propPairs { propName := propPair.propName propSchema := propPair.propSchema - propRuleName := sc.visit(propSchema, fmt.Sprintf("%s-%s", ruleName, propName), rootSchema) - + propRuleName, err := sc.visit(propSchema, fmt.Sprintf("%s-%s", ruleName, propName), rootSchema) + if err != nil { + return "", err + } + lPropName, err := sc.formatLiteral(propName) + if err != nil { + return "", err + } if i > 0 { rule.WriteString(` "," space`) } - rule.WriteString(fmt.Sprintf(` %s space ":" space %s`, sc.formatLiteral(propName), propRuleName)) + rule.WriteString(fmt.Sprintf(` %s space ":" space %s`, lPropName, propRuleName)) } rule.WriteString(` "}" space`) - return sc.addRule(ruleName, rule.String()) + return sc.addRule(ruleName, rule.String()), nil } else if items, exists := schema["items"].(map[string]interface{}); schemaType == "array" && exists { - itemRuleName := sc.visit(items, fmt.Sprintf("%s-item", ruleName), rootSchema) + itemRuleName, err := sc.visit(items, fmt.Sprintf("%s-item", ruleName), rootSchema) + if err != nil { + return "", err + } rule := fmt.Sprintf(`"[" space (%s ("," space %s)*)? "]" space`, itemRuleName, itemRuleName) - return sc.addRule(ruleName, rule) + return sc.addRule(ruleName, rule), nil } else { primitiveRule, exists := PRIMITIVE_RULES[schemaType] if !exists { @@ -236,7 +262,7 @@ func (sc *JSONSchemaConverter) visit(schema map[string]interface{}, name string, if ruleName == "root" { schemaType = "root" } - return sc.addRule(schemaType, primitiveRule) + return sc.addRule(schemaType, primitiveRule), nil } } func (sc *JSONSchemaConverter) resolveReference(ref string, rootSchema map[string]interface{}) map[string]interface{} { @@ -262,14 +288,20 @@ func (sc *JSONSchemaConverter) resolveReference(ref string, rootSchema map[strin return def } -func (sc *JSONSchemaConverter) Grammar(schema map[string]interface{}, options ...func(*GrammarOption)) string { +func (sc *JSONSchemaConverter) Grammar(schema map[string]interface{}, options ...func(*GrammarOption)) (string, error) { sc.addRule("freestring", PRIMITIVE_RULES["freestring"]) - sc.visit(schema, "", schema) - return sc.finalizeGrammar(options...) + _, err := sc.visit(schema, "", schema) + if err != nil { + return "", err + } + return sc.finalizeGrammar(options...), nil } -func (sc *JSONSchemaConverter) GrammarFromBytes(b []byte, options ...func(*GrammarOption)) string { +func (sc *JSONSchemaConverter) GrammarFromBytes(b []byte, options ...func(*GrammarOption)) (string, error) { var schema map[string]interface{} - _ = json.Unmarshal(b, &schema) + err := json.Unmarshal(b, &schema) + if err != nil { + return "", err + } return sc.Grammar(schema, options...) } diff --git a/pkg/functions/grammar_json_schema_test.go b/pkg/functions/grammar_json_schema_test.go index 6402bb40448..56c5fe1e611 100644 --- a/pkg/functions/grammar_json_schema_test.go +++ b/pkg/functions/grammar_json_schema_test.go @@ -3,7 +3,6 @@ package functions_test import ( "strings" - "github.com/mudler/LocalAI/pkg/functions" . "github.com/mudler/LocalAI/pkg/functions" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" @@ -235,7 +234,8 @@ root-1-name ::= "\"search\""` var _ = Describe("JSON schema grammar tests", func() { Context("JSON", func() { It("generates a valid grammar from JSON schema", func() { - grammar := NewJSONSchemaConverter("").GrammarFromBytes([]byte(testInput1)) + grammar, err := NewJSONSchemaConverter("").GrammarFromBytes([]byte(testInput1)) + Expect(err).To(BeNil()) results := strings.Split(inputResult1, "\n") for _, r := range results { if r != "" { @@ -245,7 +245,8 @@ var _ = Describe("JSON schema grammar tests", func() { Expect(len(results)).To(Equal(len(strings.Split(grammar, "\n")))) }) It("generates a valid grammar from JSON schema", func() { - grammar := NewJSONSchemaConverter("").GrammarFromBytes([]byte(testInput2)) + grammar, err := NewJSONSchemaConverter("").GrammarFromBytes([]byte(testInput2)) + Expect(err).To(BeNil()) results := strings.Split(inputResult3, "\n") for _, r := range results { if r != "" { @@ -259,7 +260,8 @@ var _ = Describe("JSON schema grammar tests", func() { structuredGrammar := JSONFunctionStructure{ OneOf: testFunctions} - grammar := structuredGrammar.Grammar() + grammar, err := structuredGrammar.Grammar() + Expect(err).To(BeNil()) results := strings.Split(inputResult1, "\n") for _, r := range results { if r != "" { @@ -273,7 +275,8 @@ var _ = Describe("JSON schema grammar tests", func() { structuredGrammar := JSONFunctionStructure{ OneOf: testFunctions} - grammar := structuredGrammar.Grammar(functions.EnableMaybeArray) + grammar, err := structuredGrammar.Grammar(EnableMaybeArray) + Expect(err).To(BeNil()) results := strings.Split( strings.Join([]string{ inputResult2, @@ -291,7 +294,8 @@ var _ = Describe("JSON schema grammar tests", func() { structuredGrammar := JSONFunctionStructure{ OneOf: testFunctionsName} - grammar := structuredGrammar.Grammar(functions.EnableMaybeArray) + grammar, err := structuredGrammar.Grammar(EnableMaybeArray) + Expect(err).To(BeNil()) results := strings.Split( strings.Join([]string{ inputResult4, @@ -309,10 +313,11 @@ var _ = Describe("JSON schema grammar tests", func() { structuredGrammar := JSONFunctionStructure{ OneOf: testFunctionsName} - grammar := structuredGrammar.Grammar( - functions.SetPrefix("suffix"), - functions.EnableMaybeArray, + grammar, err := structuredGrammar.Grammar( + SetPrefix("suffix"), + EnableMaybeArray, ) + Expect(err).To(BeNil()) results := strings.Split( strings.Join([]string{ rootResult(`"suffix" arr | realvalue`), @@ -329,7 +334,8 @@ var _ = Describe("JSON schema grammar tests", func() { structuredGrammar := JSONFunctionStructure{ OneOf: testFunctionsName} - grammar := structuredGrammar.Grammar(functions.SetPrefix("suffix")) + grammar, err := structuredGrammar.Grammar(SetPrefix("suffix")) + Expect(err).To(BeNil()) results := strings.Split( strings.Join([]string{ rootResult(`"suffix" realvalue`), @@ -346,7 +352,8 @@ var _ = Describe("JSON schema grammar tests", func() { structuredGrammar := JSONFunctionStructure{ OneOf: testFunctionsName} - grammar := structuredGrammar.Grammar(functions.SetPrefix("suffix"), functions.EnableMaybeString) + grammar, err := structuredGrammar.Grammar(SetPrefix("suffix"), EnableMaybeString) + Expect(err).To(BeNil()) results := strings.Split( strings.Join([]string{ rootResult(`( "suffix" realvalue | mixedstring )`), @@ -363,7 +370,8 @@ var _ = Describe("JSON schema grammar tests", func() { structuredGrammar := JSONFunctionStructure{ OneOf: testFunctionsName} - grammar := structuredGrammar.Grammar(functions.SetPrefix("suffix"), functions.EnableMaybeString, functions.EnableMaybeArray) + grammar, err := structuredGrammar.Grammar(SetPrefix("suffix"), EnableMaybeString, EnableMaybeArray) + Expect(err).To(BeNil()) results := strings.Split( strings.Join([]string{ rootResult(`( "suffix" (arr | realvalue) | mixedstring )`), @@ -382,7 +390,8 @@ var _ = Describe("JSON schema grammar tests", func() { structuredGrammar := JSONFunctionStructure{ OneOf: testFunctionsName} - grammar := structuredGrammar.Grammar(functions.EnableMaybeString, functions.EnableMaybeArray) + grammar, err := structuredGrammar.Grammar(EnableMaybeString, EnableMaybeArray) + Expect(err).To(BeNil()) results := strings.Split( strings.Join([]string{ rootResult(`mixedstring | arr | realvalue`), @@ -400,7 +409,8 @@ var _ = Describe("JSON schema grammar tests", func() { structuredGrammar := JSONFunctionStructure{ OneOf: testFunctionsName} - grammar := structuredGrammar.Grammar(functions.EnableMaybeString, functions.EnableMaybeArray, functions.NoMixedFreeString) + grammar, err := structuredGrammar.Grammar(EnableMaybeString, EnableMaybeArray, NoMixedFreeString) + Expect(err).To(BeNil()) results := strings.Split( strings.Join([]string{ rootResult(`freestring | arr | realvalue`), @@ -422,7 +432,8 @@ var _ = Describe("JSON schema grammar tests", func() { realvalue ("," realvalue)* )? "]"` - grammar := structuredGrammar.Grammar(functions.EnableMaybeString, functions.EnableMaybeArray, functions.DisableParallelNewLines) + grammar, err := structuredGrammar.Grammar(EnableMaybeString, EnableMaybeArray, DisableParallelNewLines) + Expect(err).To(BeNil()) results := strings.Split(content, "\n") for _, r := range results { if r != "" {