Skip to content

Commit

Permalink
feat(grammar): add llama3.1 schema (#3015)
Browse files Browse the repository at this point in the history
* wip

Signed-off-by: Ettore Di Giacinto <[email protected]>

* get rid of panics

Signed-off-by: Ettore Di Giacinto <[email protected]>

* expose it properly from the config

Signed-off-by: Ettore Di Giacinto <[email protected]>

* Simplify

Signed-off-by: Ettore Di Giacinto <[email protected]>

* forgot to commit

Signed-off-by: Ettore Di Giacinto <[email protected]>

* Remove focus on test

Signed-off-by: Ettore Di Giacinto <[email protected]>

* Small fixups

Signed-off-by: Ettore Di Giacinto <[email protected]>

---------

Signed-off-by: Ettore Di Giacinto <[email protected]>
  • Loading branch information
mudler authored Jul 26, 2024
1 parent fee5294 commit 2169c34
Show file tree
Hide file tree
Showing 14 changed files with 609 additions and 148 deletions.
4 changes: 2 additions & 2 deletions core/http/endpoints/openai/chat.go
Original file line number Diff line number Diff line change
Expand Up @@ -226,12 +226,12 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, startup

// Update input grammar
jsStruct := funcs.ToJSONStructure(config.FunctionsConfig.FunctionNameKey, config.FunctionsConfig.FunctionNameKey)
g, err := jsStruct.Grammar(config.FunctionsConfig.GrammarConfig.Options()...)
g, err := jsStruct.Grammar(config.FunctionsConfig.GrammarOptions()...)
if err == nil {
config.Grammar = g
}
case input.JSONFunctionGrammarObject != nil:
g, err := input.JSONFunctionGrammarObject.Grammar(config.FunctionsConfig.GrammarConfig.Options()...)
g, err := input.JSONFunctionGrammarObject.Grammar(config.FunctionsConfig.GrammarOptions()...)
if err == nil {
config.Grammar = g
}
Expand Down
26 changes: 22 additions & 4 deletions pkg/functions/function_structure.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
package functions

import "encoding/json"
import (
"encoding/json"

"github.com/mudler/LocalAI/pkg/functions/grammars"
)

type Item struct {
Type string `json:"type"`
Expand All @@ -13,13 +17,27 @@ type JSONFunctionStructure struct {
Defs map[string]interface{} `json:"$defs,omitempty"`
}

func (j JSONFunctionStructure) Grammar(options ...func(*GrammarOption)) (string, error) {
grammarOpts := &GrammarOption{}
func (j JSONFunctionStructure) Grammar(options ...func(*grammars.GrammarOption)) (string, error) {
grammarOpts := &grammars.GrammarOption{}
grammarOpts.Apply(options...)

dat, err := json.Marshal(j)
if err != nil {
return "", err
}
return NewJSONSchemaConverter(grammarOpts.PropOrder).GrammarFromBytes(dat, options...)

converter := NewSchemaConverter(*grammarOpts)
return converter.GrammarFromBytes(dat, options...)
}

// SchemaConverter is the common contract of the grammar back-ends returned by
// NewSchemaConverter: anything able to turn raw schema bytes into a grammar
// string.
type SchemaConverter interface {
	// GrammarFromBytes builds the grammar text from the serialized schema,
	// applying the given grammar options. It returns the grammar or an error.
	GrammarFromBytes(schema []byte, options ...func(*grammars.GrammarOption)) (string, error)
}

// NewSchemaConverter selects the converter implementation for opt.
//
// When opt.SchemaType equals grammars.LLama31Schema the LLama 3.1 converter is
// returned, constructed from opt.FunctionName. Every other schema type falls
// back to the JSON schema converter, constructed from opt.PropOrder.
func NewSchemaConverter(opt grammars.GrammarOption) SchemaConverter {
	if opt.SchemaType == grammars.LLama31Schema {
		return grammars.NewLLama31SchemaConverter(opt.FunctionName)
	}

	return grammars.NewJSONSchemaConverter(opt.PropOrder)
}
8 changes: 0 additions & 8 deletions pkg/functions/functions.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,11 +95,3 @@ func (f Functions) Select(name string) Functions {

return funcs
}

// jsonString marshals v with encoding/json and returns the encoding as a
// string. On a marshalling failure it returns the empty string and the error.
func jsonString(v interface{}) (string, error) {
	encoded, err := json.Marshal(v)
	if err == nil {
		return string(encoded), nil
	}
	return "", err
}
16 changes: 2 additions & 14 deletions pkg/functions/functions_suite_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,23 +3,11 @@ package functions_test
import (
"testing"

. "github.com/mudler/LocalAI/pkg/functions"

. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)

func TestGrammar(t *testing.T) {
func TestFunctions(t *testing.T) {
RegisterFailHandler(Fail)
RunSpecs(t, "Grammar test suite")
}

func createFunction(field1 string, field2 string, name string, properties map[string]interface{}) map[string]interface{} {
property := map[string]interface{}{}
property[field1] = FunctionName{Const: name}
property[field2] = Argument{
Type: "object",
Properties: properties,
}
return property
RunSpecs(t, "Functions test suite")
}
15 changes: 13 additions & 2 deletions pkg/functions/bnf_rules.go → pkg/functions/grammars/bnf_rules.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
package functions
package grammars

import "regexp"
import (
"encoding/json"
"regexp"
)

var (
PRIMITIVE_RULES = map[string]string{
Expand Down Expand Up @@ -45,3 +48,11 @@ const (
("," realvalue)*
)? "]"`
)

// jsonString returns the JSON encoding of v as a string. If v cannot be
// marshalled, the result is the empty string together with the marshal error.
func jsonString(v interface{}) (string, error) {
	raw, err := json.Marshal(v)
	if err == nil {
		return string(raw), nil
	}
	return "", err
}
25 changes: 25 additions & 0 deletions pkg/functions/grammars/grammars_suite_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
package grammars_test

import (
"testing"

. "github.com/mudler/LocalAI/pkg/functions"

. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)

// TestGrammar is the `go test` entry point for this package's spec suite: it
// registers Fail as the failure handler and then runs the specs under the
// description "Grammar test suite".
func TestGrammar(t *testing.T) {
RegisterFailHandler(Fail)
RunSpecs(t, "Grammar test suite")
}

// createFunction is a spec helper that assembles a property map describing one
// function: key field1 holds a FunctionName whose Const is name, and key field2
// holds an Argument of type "object" carrying the given properties.
func createFunction(field1 string, field2 string, name string, properties map[string]interface{}) map[string]interface{} {
	fnName := FunctionName{Const: name}
	fnArgs := Argument{
		Type:       "object",
		Properties: properties,
	}

	// Assign in the same order as before so that field2 wins if both keys are equal.
	out := make(map[string]interface{})
	out[field1] = fnName
	out[field2] = fnArgs
	return out
}
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package functions
package grammars

// a golang port of https://github.com/ggerganov/llama.cpp/pull/1887

Expand All @@ -7,13 +7,11 @@ import (
"fmt"
"sort"
"strings"

"github.com/mudler/LocalAI/pkg/utils"
)

type JSONSchemaConverter struct {
propOrder map[string]int
rules map[string]string
rules Rules
}

func NewJSONSchemaConverter(propOrder string) *JSONSchemaConverter {
Expand Down Expand Up @@ -60,90 +58,6 @@ func (sc *JSONSchemaConverter) addRule(name, rule string) string {
return key
}

// finalizeGrammar renders the rules accumulated in sc.rules as grammar text,
// one "name ::= rule" line per rule, joined with newlines.
//
// The supplied options are applied to a fresh GrammarOption. If none of
// MaybeArray, MaybeString or a non-empty Prefix is set, the rules are emitted
// unchanged. Otherwise the rule named "root" is emitted under the name
// "realvalue" and a new "root" rule is synthesized from the options, followed
// by an array rule line and a "mixedstring" rule line.
//
// NOTE(review): sc.rules is a map, so the order of the emitted rule lines
// follows Go's unspecified map iteration order and may differ between calls.
func (sc *JSONSchemaConverter) finalizeGrammar(options ...func(*GrammarOption)) string {

grammarOpts := &GrammarOption{}
grammarOpts.Apply(options...)

prefix := grammarOpts.Prefix
maybeArray := grammarOpts.MaybeArray
disableParallelNewLines := grammarOpts.DisableParallelNewLines
maybeString := grammarOpts.MaybeString
noMixedFreeString := grammarOpts.NoMixedFreeString

var lines []string

// swapRoot is true whenever the original root has to be wrapped by a new one.
swapRoot := maybeArray || maybeString || prefix != ""

// write down the computed rules.
// if maybeArray is true, we need to add the array rule and slightly tweak the root rule
for name, rule := range sc.rules {
if swapRoot && name == "root" {
// the original root becomes "realvalue"; the new root is appended below
name = "realvalue"
}
lines = append(lines, fmt.Sprintf("%s ::= %s", name, rule))
}

// nothing to wrap: the computed rules are the whole grammar
if !swapRoot {
return strings.Join(lines, "\n")
}

newRoot := "realvalue"
if maybeArray {
newRoot = "arr | realvalue"
}

// name of the rule used as the free-text alternative of the new root
freestringRule := "mixedstring"
if noMixedFreeString {
freestringRule = "freestring"
}

if prefix != "" {
// escape newlines in the prefix before embedding it in a quoted literal
prefix = utils.EscapeNewLines(prefix)

// parenthesize "arr | realvalue" so the prefix literal applies to both alternatives
if maybeArray && maybeString {
newRoot = "(" + newRoot + ")"
}

if maybeString {
//newRoot = "( (\"" + suffix + "\" " + newRoot + ") | freestring ) "
newRoot = "( \"" + prefix + "\" " + newRoot + " | " + freestringRule + " ) "
} else {
newRoot = "\"" + prefix + "\" " + "" + newRoot + ""
}
} else if maybeString {
// NOTE(review): this branch is a no-op; its only statement is commented out
if maybeArray {
// newRoot = "(" + newRoot + ")"
}

newRoot = freestringRule + " | " + newRoot
}

lines = append(lines, fmt.Sprintf("%s ::= %s", "root", newRoot))
// array / arrayNewLines are defined outside this view, presumably the two
// variants of the "arr" rule text; TODO confirm. One of them is appended
// whenever the root was swapped, even when maybeArray is false.
if disableParallelNewLines {
lines = append(lines, array)
} else {
lines = append(lines, arrayNewLines)
}

// The mixedstring rule is always appended on this path; its shape depends on
// whether arrays are allowed and on ExpectStringsAfterJSON.
if maybeArray {
if grammarOpts.ExpectStringsAfterJSON {
lines = append(lines, `mixedstring ::= freestring | freestring arr freestring | (freestring realvalue freestring)* | realvalue | arr`)
} else {
lines = append(lines, `mixedstring ::= freestring | freestring arr | freestring realvalue | realvalue | arr`)
}
} else {
if grammarOpts.ExpectStringsAfterJSON {
lines = append(lines, `mixedstring ::= freestring | (freestring realvalue freestring)* | realvalue`)
} else {
lines = append(lines, `mixedstring ::= freestring | freestring realvalue | realvalue`)
}
}

return strings.Join(lines, "\n")
}

func (sc *JSONSchemaConverter) visit(schema map[string]interface{}, name string, rootSchema map[string]interface{}) (string, error) {
st, existType := schema["type"]
var schemaType string
Expand Down Expand Up @@ -182,7 +96,10 @@ func (sc *JSONSchemaConverter) visit(schema map[string]interface{}, name string,
rule := strings.Join(alternatives, " | ")
return sc.addRule(ruleName, rule), nil
} else if ref, exists := schema["$ref"].(string); exists {
referencedSchema := sc.resolveReference(ref, rootSchema)
referencedSchema, err := sc.resolveReference(ref, rootSchema)
if err != nil {
return "", err
}
return sc.visit(referencedSchema, name, rootSchema)
} else if constVal, exists := schema["const"]; exists {
literal, err := sc.formatLiteral((constVal))
Expand Down Expand Up @@ -257,35 +174,31 @@ func (sc *JSONSchemaConverter) visit(schema map[string]interface{}, name string,
} else {
primitiveRule, exists := PRIMITIVE_RULES[schemaType]
if !exists {
panic(fmt.Sprintf("Unrecognized schema: %v", schema))
return "", fmt.Errorf("unrecognized schema: %v", schema)
}
if ruleName == "root" {
schemaType = "root"
}
return sc.addRule(schemaType, primitiveRule), nil
}
}
func (sc *JSONSchemaConverter) resolveReference(ref string, rootSchema map[string]interface{}) map[string]interface{} {
func (sc *JSONSchemaConverter) resolveReference(ref string, rootSchema map[string]interface{}) (map[string]interface{}, error) {
if !strings.HasPrefix(ref, "#/$defs/") {
panic(fmt.Sprintf("Invalid reference format: %s", ref))
return nil, fmt.Errorf("invalid reference format: %s", ref)
}

defKey := strings.TrimPrefix(ref, "#/$defs/")
definitions, exists := rootSchema["$defs"].(map[string]interface{})
if !exists {
fmt.Println(rootSchema)

panic("No definitions found in the schema")
return nil, fmt.Errorf("no definitions found in the schema: %s", rootSchema)
}

def, exists := definitions[defKey].(map[string]interface{})
if !exists {
fmt.Println(definitions)

panic(fmt.Sprintf("Definition not found: %s", defKey))
return nil, fmt.Errorf("definition not found: %s %+v", defKey, definitions)
}

return def
return def, nil
}

func (sc *JSONSchemaConverter) Grammar(schema map[string]interface{}, options ...func(*GrammarOption)) (string, error) {
Expand All @@ -294,7 +207,7 @@ func (sc *JSONSchemaConverter) Grammar(schema map[string]interface{}, options ..
if err != nil {
return "", err
}
return sc.finalizeGrammar(options...), nil
return sc.rules.ToGrammar(options...), nil
}

func (sc *JSONSchemaConverter) GrammarFromBytes(b []byte, options ...func(*GrammarOption)) (string, error) {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
package functions_test
package grammars_test

import (
"strings"

. "github.com/mudler/LocalAI/pkg/functions"
. "github.com/mudler/LocalAI/pkg/functions/grammars"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
Expand Down
Loading

0 comments on commit 2169c34

Please sign in to comment.