Skip to content

Commit

Permalink
propagate errors around
Browse files Browse the repository at this point in the history
Signed-off-by: Dave Lee <[email protected]>
  • Loading branch information
dave-gray101 committed Jul 25, 2024
1 parent ff55f31 commit 821d78a
Show file tree
Hide file tree
Showing 5 changed files with 99 additions and 44 deletions.
10 changes: 8 additions & 2 deletions core/http/endpoints/openai/chat.go
Original file line number Diff line number Diff line change
Expand Up @@ -226,9 +226,15 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, startup

// Update input grammar
jsStruct := funcs.ToJSONStructure(config.FunctionsConfig.FunctionNameKey, config.FunctionsConfig.FunctionNameKey)
config.Grammar = jsStruct.Grammar(config.FunctionsConfig.GrammarConfig.Options()...)
g, err := jsStruct.Grammar(config.FunctionsConfig.GrammarConfig.Options()...)
if err == nil {
config.Grammar = g
}
case input.JSONFunctionGrammarObject != nil:
config.Grammar = input.JSONFunctionGrammarObject.Grammar(config.FunctionsConfig.GrammarConfig.Options()...)
g, err := input.JSONFunctionGrammarObject.Grammar(config.FunctionsConfig.GrammarConfig.Options()...)
if err == nil {
config.Grammar = g
}
default:
// Force picking one of the functions by the request
if config.FunctionToCall() != "" {
Expand Down
7 changes: 5 additions & 2 deletions pkg/functions/function_structure.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,13 @@ type JSONFunctionStructure struct {
Defs map[string]interface{} `json:"$defs,omitempty"`
}

func (j JSONFunctionStructure) Grammar(options ...func(*GrammarOption)) string {
func (j JSONFunctionStructure) Grammar(options ...func(*GrammarOption)) (string, error) {
grammarOpts := &GrammarOption{}
grammarOpts.Apply(options...)

dat, _ := json.Marshal(j)
dat, err := json.Marshal(j)
if err != nil {
return "", err
}
return NewJSONSchemaConverter(grammarOpts.PropOrder).GrammarFromBytes(dat, options...)
}
9 changes: 6 additions & 3 deletions pkg/functions/functions.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,10 @@ func (f Functions) Select(name string) Functions {
return funcs
}

func jsonString(v interface{}) string {
b, _ := json.Marshal(v)
return string(b)
func jsonString(v interface{}) (string, error) {
b, err := json.Marshal(v)
if err != nil {
return "", err
}
return string(b), nil
}
76 changes: 54 additions & 22 deletions pkg/functions/grammar_json_schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,15 @@ func NewJSONSchemaConverter(propOrder string) *JSONSchemaConverter {
}
}

func (sc *JSONSchemaConverter) formatLiteral(literal interface{}) string {
escaped := GRAMMAR_LITERAL_ESCAPE_RE.ReplaceAllStringFunc(jsonString(literal), func(match string) string {
func (sc *JSONSchemaConverter) formatLiteral(literal interface{}) (string, error) {
jLiteral, err := jsonString(literal)
if err != nil {
return "", err
}
escaped := GRAMMAR_LITERAL_ESCAPE_RE.ReplaceAllStringFunc(jLiteral, func(match string) string {
return GRAMMAR_LITERAL_ESCAPES[match]
})
return fmt.Sprintf(`"%s"`, escaped)
return fmt.Sprintf(`"%s"`, escaped), nil
}

func (sc *JSONSchemaConverter) addRule(name, rule string) string {
Expand Down Expand Up @@ -140,7 +144,7 @@ func (sc *JSONSchemaConverter) finalizeGrammar(options ...func(*GrammarOption))
return strings.Join(lines, "\n")
}

func (sc *JSONSchemaConverter) visit(schema map[string]interface{}, name string, rootSchema map[string]interface{}) string {
func (sc *JSONSchemaConverter) visit(schema map[string]interface{}, name string, rootSchema map[string]interface{}) (string, error) {
st, existType := schema["type"]
var schemaType string
if existType {
Expand All @@ -159,31 +163,44 @@ func (sc *JSONSchemaConverter) visit(schema map[string]interface{}, name string,

if oneOfExists {
for i, altSchema := range oneOfSchemas {
alternative := sc.visit(altSchema.(map[string]interface{}), fmt.Sprintf("%s-%d", ruleName, i), rootSchema)
alternative, err := sc.visit(altSchema.(map[string]interface{}), fmt.Sprintf("%s-%d", ruleName, i), rootSchema)
if err != nil {
return "", err
}
alternatives = append(alternatives, alternative)
}
} else if anyOfExists {
for i, altSchema := range anyOfSchemas {
alternative := sc.visit(altSchema.(map[string]interface{}), fmt.Sprintf("%s-%d", ruleName, i), rootSchema)
alternative, err := sc.visit(altSchema.(map[string]interface{}), fmt.Sprintf("%s-%d", ruleName, i), rootSchema)
if err != nil {
return "", err
}
alternatives = append(alternatives, alternative)
}
}

rule := strings.Join(alternatives, " | ")
return sc.addRule(ruleName, rule)
return sc.addRule(ruleName, rule), nil
} else if ref, exists := schema["$ref"].(string); exists {
referencedSchema := sc.resolveReference(ref, rootSchema)
return sc.visit(referencedSchema, name, rootSchema)
} else if constVal, exists := schema["const"]; exists {
return sc.addRule(ruleName, sc.formatLiteral(constVal))
literal, err := sc.formatLiteral((constVal))
if err != nil {
return "", err
}
return sc.addRule(ruleName, literal), nil
} else if enumVals, exists := schema["enum"].([]interface{}); exists {
var enumRules []string
for _, enumVal := range enumVals {
enumRule := sc.formatLiteral(enumVal)
enumRule, err := sc.formatLiteral(enumVal)
if err != nil {
return "", err
}
enumRules = append(enumRules, enumRule)
}
rule := strings.Join(enumRules, " | ")
return sc.addRule(ruleName, rule)
return sc.addRule(ruleName, rule), nil
} else if properties, exists := schema["properties"].(map[string]interface{}); schemaType == "object" && exists {
propOrder := sc.propOrder
var propPairs []struct {
Expand Down Expand Up @@ -213,21 +230,30 @@ func (sc *JSONSchemaConverter) visit(schema map[string]interface{}, name string,
for i, propPair := range propPairs {
propName := propPair.propName
propSchema := propPair.propSchema
propRuleName := sc.visit(propSchema, fmt.Sprintf("%s-%s", ruleName, propName), rootSchema)

propRuleName, err := sc.visit(propSchema, fmt.Sprintf("%s-%s", ruleName, propName), rootSchema)
if err != nil {
return "", err
}
lPropName, err := sc.formatLiteral(propName)
if err != nil {
return "", err
}
if i > 0 {
rule.WriteString(` "," space`)
}

rule.WriteString(fmt.Sprintf(` %s space ":" space %s`, sc.formatLiteral(propName), propRuleName))
rule.WriteString(fmt.Sprintf(` %s space ":" space %s`, lPropName, propRuleName))
}

rule.WriteString(` "}" space`)
return sc.addRule(ruleName, rule.String())
return sc.addRule(ruleName, rule.String()), nil
} else if items, exists := schema["items"].(map[string]interface{}); schemaType == "array" && exists {
itemRuleName := sc.visit(items, fmt.Sprintf("%s-item", ruleName), rootSchema)
itemRuleName, err := sc.visit(items, fmt.Sprintf("%s-item", ruleName), rootSchema)
if err != nil {
return "", err
}
rule := fmt.Sprintf(`"[" space (%s ("," space %s)*)? "]" space`, itemRuleName, itemRuleName)
return sc.addRule(ruleName, rule)
return sc.addRule(ruleName, rule), nil
} else {
primitiveRule, exists := PRIMITIVE_RULES[schemaType]
if !exists {
Expand All @@ -236,7 +262,7 @@ func (sc *JSONSchemaConverter) visit(schema map[string]interface{}, name string,
if ruleName == "root" {
schemaType = "root"
}
return sc.addRule(schemaType, primitiveRule)
return sc.addRule(schemaType, primitiveRule), nil
}
}
func (sc *JSONSchemaConverter) resolveReference(ref string, rootSchema map[string]interface{}) map[string]interface{} {
Expand All @@ -262,14 +288,20 @@ func (sc *JSONSchemaConverter) resolveReference(ref string, rootSchema map[strin
return def
}

func (sc *JSONSchemaConverter) Grammar(schema map[string]interface{}, options ...func(*GrammarOption)) string {
func (sc *JSONSchemaConverter) Grammar(schema map[string]interface{}, options ...func(*GrammarOption)) (string, error) {
sc.addRule("freestring", PRIMITIVE_RULES["freestring"])
sc.visit(schema, "", schema)
return sc.finalizeGrammar(options...)
_, err := sc.visit(schema, "", schema)
if err != nil {
return "", err
}
return sc.finalizeGrammar(options...), nil
}

func (sc *JSONSchemaConverter) GrammarFromBytes(b []byte, options ...func(*GrammarOption)) string {
func (sc *JSONSchemaConverter) GrammarFromBytes(b []byte, options ...func(*GrammarOption)) (string, error) {
var schema map[string]interface{}
_ = json.Unmarshal(b, &schema)
err := json.Unmarshal(b, &schema)
if err != nil {
return "", err
}
return sc.Grammar(schema, options...)
}
41 changes: 26 additions & 15 deletions pkg/functions/grammar_json_schema_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package functions_test
import (
"strings"

"github.com/mudler/LocalAI/pkg/functions"
. "github.com/mudler/LocalAI/pkg/functions"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
Expand Down Expand Up @@ -235,7 +234,8 @@ root-1-name ::= "\"search\""`
var _ = Describe("JSON schema grammar tests", func() {
Context("JSON", func() {
It("generates a valid grammar from JSON schema", func() {
grammar := NewJSONSchemaConverter("").GrammarFromBytes([]byte(testInput1))
grammar, err := NewJSONSchemaConverter("").GrammarFromBytes([]byte(testInput1))
Expect(err).To(BeNil())
results := strings.Split(inputResult1, "\n")
for _, r := range results {
if r != "" {
Expand All @@ -245,7 +245,8 @@ var _ = Describe("JSON schema grammar tests", func() {
Expect(len(results)).To(Equal(len(strings.Split(grammar, "\n"))))
})
It("generates a valid grammar from JSON schema", func() {
grammar := NewJSONSchemaConverter("").GrammarFromBytes([]byte(testInput2))
grammar, err := NewJSONSchemaConverter("").GrammarFromBytes([]byte(testInput2))
Expect(err).To(BeNil())
results := strings.Split(inputResult3, "\n")
for _, r := range results {
if r != "" {
Expand All @@ -259,7 +260,8 @@ var _ = Describe("JSON schema grammar tests", func() {
structuredGrammar := JSONFunctionStructure{
OneOf: testFunctions}

grammar := structuredGrammar.Grammar()
grammar, err := structuredGrammar.Grammar()
Expect(err).To(BeNil())
results := strings.Split(inputResult1, "\n")
for _, r := range results {
if r != "" {
Expand All @@ -273,7 +275,8 @@ var _ = Describe("JSON schema grammar tests", func() {
structuredGrammar := JSONFunctionStructure{
OneOf: testFunctions}

grammar := structuredGrammar.Grammar(functions.EnableMaybeArray)
grammar, err := structuredGrammar.Grammar(EnableMaybeArray)
Expect(err).To(BeNil())
results := strings.Split(
strings.Join([]string{
inputResult2,
Expand All @@ -291,7 +294,8 @@ var _ = Describe("JSON schema grammar tests", func() {
structuredGrammar := JSONFunctionStructure{
OneOf: testFunctionsName}

grammar := structuredGrammar.Grammar(functions.EnableMaybeArray)
grammar, err := structuredGrammar.Grammar(EnableMaybeArray)
Expect(err).To(BeNil())
results := strings.Split(
strings.Join([]string{
inputResult4,
Expand All @@ -309,10 +313,11 @@ var _ = Describe("JSON schema grammar tests", func() {
structuredGrammar := JSONFunctionStructure{
OneOf: testFunctionsName}

grammar := structuredGrammar.Grammar(
functions.SetPrefix("suffix"),
functions.EnableMaybeArray,
grammar, err := structuredGrammar.Grammar(
SetPrefix("suffix"),
EnableMaybeArray,
)
Expect(err).To(BeNil())
results := strings.Split(
strings.Join([]string{
rootResult(`"suffix" arr | realvalue`),
Expand All @@ -329,7 +334,8 @@ var _ = Describe("JSON schema grammar tests", func() {
structuredGrammar := JSONFunctionStructure{
OneOf: testFunctionsName}

grammar := structuredGrammar.Grammar(functions.SetPrefix("suffix"))
grammar, err := structuredGrammar.Grammar(SetPrefix("suffix"))
Expect(err).To(BeNil())
results := strings.Split(
strings.Join([]string{
rootResult(`"suffix" realvalue`),
Expand All @@ -346,7 +352,8 @@ var _ = Describe("JSON schema grammar tests", func() {
structuredGrammar := JSONFunctionStructure{
OneOf: testFunctionsName}

grammar := structuredGrammar.Grammar(functions.SetPrefix("suffix"), functions.EnableMaybeString)
grammar, err := structuredGrammar.Grammar(SetPrefix("suffix"), EnableMaybeString)
Expect(err).To(BeNil())
results := strings.Split(
strings.Join([]string{
rootResult(`( "suffix" realvalue | mixedstring )`),
Expand All @@ -363,7 +370,8 @@ var _ = Describe("JSON schema grammar tests", func() {
structuredGrammar := JSONFunctionStructure{
OneOf: testFunctionsName}

grammar := structuredGrammar.Grammar(functions.SetPrefix("suffix"), functions.EnableMaybeString, functions.EnableMaybeArray)
grammar, err := structuredGrammar.Grammar(SetPrefix("suffix"), EnableMaybeString, EnableMaybeArray)
Expect(err).To(BeNil())
results := strings.Split(
strings.Join([]string{
rootResult(`( "suffix" (arr | realvalue) | mixedstring )`),
Expand All @@ -382,7 +390,8 @@ var _ = Describe("JSON schema grammar tests", func() {
structuredGrammar := JSONFunctionStructure{
OneOf: testFunctionsName}

grammar := structuredGrammar.Grammar(functions.EnableMaybeString, functions.EnableMaybeArray)
grammar, err := structuredGrammar.Grammar(EnableMaybeString, EnableMaybeArray)
Expect(err).To(BeNil())
results := strings.Split(
strings.Join([]string{
rootResult(`mixedstring | arr | realvalue`),
Expand All @@ -400,7 +409,8 @@ var _ = Describe("JSON schema grammar tests", func() {
structuredGrammar := JSONFunctionStructure{
OneOf: testFunctionsName}

grammar := structuredGrammar.Grammar(functions.EnableMaybeString, functions.EnableMaybeArray, functions.NoMixedFreeString)
grammar, err := structuredGrammar.Grammar(EnableMaybeString, EnableMaybeArray, NoMixedFreeString)
Expect(err).To(BeNil())
results := strings.Split(
strings.Join([]string{
rootResult(`freestring | arr | realvalue`),
Expand All @@ -422,7 +432,8 @@ var _ = Describe("JSON schema grammar tests", func() {
realvalue
("," realvalue)*
)? "]"`
grammar := structuredGrammar.Grammar(functions.EnableMaybeString, functions.EnableMaybeArray, functions.DisableParallelNewLines)
grammar, err := structuredGrammar.Grammar(EnableMaybeString, EnableMaybeArray, DisableParallelNewLines)
Expect(err).To(BeNil())
results := strings.Split(content, "\n")
for _, r := range results {
if r != "" {
Expand Down

0 comments on commit 821d78a

Please sign in to comment.