Skip to content

Commit

Permalink
feat!: provide plugins get_log_level runtime function & support levels (
Browse files Browse the repository at this point in the history
#74)

This PR aligns the extism host env with Rust, as per
extism/extism#758

NOTE of breaking change:

I've removed the ability to set the log level on the `*Plugin` itself.
To align with Rust SDK, this is only an SDK-global setting. You can
still add a custom logger to each plugin function. The SDK won't print
any logs unless there is a log function set for each plugin. The global
setting is in line with how libraries are expected to respect the main
application's log level.
  • Loading branch information
nilslice authored Sep 16, 2024
1 parent 39b5924 commit 24b5f36
Show file tree
Hide file tree
Showing 7 changed files with 165 additions and 90 deletions.
101 changes: 61 additions & 40 deletions extism.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,11 @@ import (
"fmt"
"io"
"log"
"math"
"net/http"
"os"
"strings"
"sync/atomic"
"time"

observe "github.com/dylibso/observe-sdk/go"
Expand All @@ -27,6 +29,9 @@ type module struct {
wasm []byte
}

type PluginCtxKey string
type InputOffsetKey string

//go:embed extism-runtime.wasm
var extismRuntimeWasm []byte

Expand Down Expand Up @@ -63,31 +68,51 @@ type HttpRequest struct {
}

// LogLevel defines different log levels.
type LogLevel uint8
type LogLevel int32

const (
logLevelUnset LogLevel = iota // unexporting this intentionally so its only ever the default
LogLevelOff
LogLevelError
LogLevelWarn
LogLevelInfo
LogLevelDebug
LogLevelTrace
LogLevelDebug
LogLevelInfo
LogLevelWarn
LogLevelError

LogLevelOff LogLevel = math.MaxInt32
)

func (l LogLevel) ExtismCompat() int32 {
switch l {
case LogLevelTrace:
return 0
case LogLevelDebug:
return 1
case LogLevelInfo:
return 2
case LogLevelWarn:
return 3
case LogLevelError:
return 4
default:
return int32(LogLevelOff)
}
}

func (l LogLevel) String() string {
s := ""
switch l {
case LogLevelError:
s = "ERROR"
case LogLevelWarn:
s = "WARN"
case LogLevelInfo:
s = "INFO"
case LogLevelDebug:
s = "DEBUG"
case LogLevelTrace:
s = "TRACE"
case LogLevelDebug:
s = "DEBUG"
case LogLevelInfo:
s = "INFO"
case LogLevelWarn:
s = "WARN"
case LogLevelError:
s = "ERROR"
default:
s = "OFF"
}
return s
}
Expand All @@ -107,7 +132,6 @@ type Plugin struct {
MaxHttpResponseBytes int64
MaxVarBytes int64
log func(LogLevel, string)
logLevel LogLevel
guestRuntime guestRuntime
Adapter *observe.AdapterBase
TraceCtx *observe.TraceCtx
Expand All @@ -122,13 +146,8 @@ func (p *Plugin) SetLogger(logger func(LogLevel, string)) {
p.log = logger
}

// SetLogLevel sets the minim logging level, applies to custom logging callbacks too
func (p *Plugin) SetLogLevel(level LogLevel) {
p.logLevel = level
}

func (p *Plugin) Log(level LogLevel, message string) {
if level > p.logLevel {
if level < LogLevel(pluginLogLevel.Load()) {
return
}

Expand Down Expand Up @@ -311,7 +330,7 @@ func (m *Manifest) UnmarshalJSON(data []byte) error {
Name: w.Name,
})
} else {
return errors.New("Invalid Wasm entry")
return errors.New("invalid Wasm entry")
}
}
return nil
Expand All @@ -327,6 +346,14 @@ func (p *Plugin) CloseWithContext(ctx context.Context) error {
return p.Runtime.Wazero.Close(ctx)
}

// add an atomic global to store the plugin runtime-wide log level
var pluginLogLevel = atomic.Int32{}

// SetPluginLogLevel sets the log level for the plugin
func SetLogLevel(level LogLevel) {
pluginLogLevel.Store(int32(level.ExtismCompat()))
}

// NewPlugin creates a new Extism plugin with the given manifest, configuration, and host functions.
// The returned plugin can be used to call WebAssembly functions and interact with the plugin.
func NewPlugin(
Expand Down Expand Up @@ -390,7 +417,7 @@ func NewPlugin(

count := len(manifest.Wasm)
if count == 0 {
return nil, fmt.Errorf("Manifest can't be empty.")
return nil, fmt.Errorf("manifest can't be empty")
}

modules := map[string]module{}
Expand Down Expand Up @@ -444,7 +471,7 @@ func NewPlugin(
if data.Name == "main" && config.ObserveAdapter != nil {
trace, err = config.ObserveAdapter.NewTraceCtx(ctx, c.Wazero, data.Data, config.ObserveOptions)
if err != nil {
return nil, fmt.Errorf("Failed to initialize Observe Adapter: %v", err)
return nil, fmt.Errorf("failed to initialize Observe Adapter: %v", err)
}

trace.Finish()
Expand All @@ -454,13 +481,13 @@ func NewPlugin(
_, okm := modules[data.Name]

if data.Name == "extism:host/env" || okh || okm {
return nil, fmt.Errorf("Module name collision: '%s'", data.Name)
return nil, fmt.Errorf("module name collision: '%s'", data.Name)
}

if data.Hash != "" {
calculatedHash := calculateHash(data.Data)
if data.Hash != calculatedHash {
return nil, fmt.Errorf("Hash mismatch for module '%s'", data.Name)
return nil, fmt.Errorf("hash mismatch for module '%s'", data.Name)
}
}

Expand All @@ -472,11 +499,6 @@ func NewPlugin(
modules[data.Name] = module{module: m, wasm: data.Data}
}

logLevel := LogLevelWarn
if config.LogLevel != logLevelUnset {
logLevel = config.LogLevel
}

i := 0
httpMax := int64(1024 * 1024 * 50)
if manifest.Memory != nil && manifest.Memory.MaxHttpResponseBytes >= 0 {
Expand All @@ -502,7 +524,6 @@ func NewPlugin(
MaxHttpResponseBytes: httpMax,
MaxVarBytes: varMax,
log: logStd,
logLevel: logLevel,
Adapter: config.ObserveAdapter,
TraceCtx: trace,
}
Expand All @@ -514,7 +535,7 @@ func NewPlugin(
i++
}

return nil, errors.New("No main module found")
return nil, errors.New("no main module found")
}

// SetInput sets the input data for the plugin to be used in the next WebAssembly function call.
Expand Down Expand Up @@ -612,28 +633,28 @@ func (plugin *Plugin) CallWithContext(ctx context.Context, name string, data []b
defer cancel()
}

ctx = context.WithValue(ctx, "plugin", plugin)
ctx = context.WithValue(ctx, PluginCtxKey("plugin"), plugin)

intputOffset, err := plugin.SetInput(data)
if err != nil {
return 1, []byte{}, err
}

ctx = context.WithValue(ctx, "inputOffset", intputOffset)
ctx = context.WithValue(ctx, InputOffsetKey("inputOffset"), intputOffset)

var f = plugin.Main.module.ExportedFunction(name)

if f == nil {
return 1, []byte{}, errors.New(fmt.Sprintf("Unknown function: %s", name))
return 1, []byte{}, fmt.Errorf("unknown function: %s", name)
} else if n := len(f.Definition().ResultTypes()); n > 1 {
return 1, []byte{}, errors.New(fmt.Sprintf("Function %s has %v results, expected 0 or 1", name, n))
return 1, []byte{}, fmt.Errorf("function %s has %v results, expected 0 or 1", name, n)
}

var isStart = name == "_start"
if plugin.guestRuntime.init != nil && !isStart && !plugin.guestRuntime.initialized {
err := plugin.guestRuntime.init(ctx)
if err != nil {
return 1, []byte{}, errors.New(fmt.Sprintf("failed to initialize runtime: %v", err))
return 1, []byte{}, fmt.Errorf("failed to initialize runtime: %v", err)
}
plugin.guestRuntime.initialized = true
}
Expand Down Expand Up @@ -678,14 +699,14 @@ func (plugin *Plugin) CallWithContext(ctx context.Context, name string, data []b
if rc != 0 {
errMsg := plugin.GetError()
if errMsg == "" {
errMsg = "Encountered an unknown error in call to Extism plugin function " + name
errMsg = "encountered an unknown error in call to Extism plugin function " + name
}
return rc, []byte{}, errors.New(errMsg)
}

output, err := plugin.GetOutput()
if err != nil {
return rc, []byte{}, fmt.Errorf("Failed to get output: %v", err)
return rc, []byte{}, fmt.Errorf("failed to get output: %v", err)
}

return rc, output, nil
Expand Down
71 changes: 49 additions & 22 deletions extism_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@ import (
"encoding/json"
"fmt"
"log"
"math/rand"
"os"
"strings"
"testing"
"time"

Expand Down Expand Up @@ -443,13 +443,16 @@ func TestLog_default(t *testing.T) {
if plugin, ok := plugin(t, manifest); ok {
defer plugin.Close()

SetLogLevel(LogLevelWarn) // Only warn and error logs should be printed to the console
exit, _, err := plugin.Call("run_test", []byte{})

if assertCall(t, err, exit) {
logs := buf.String()

assert.Contains(t, logs, "this is a warning log")
assert.Contains(t, logs, "this is an error log")
assert.NotContains(t, logs, "this is a trace log")
assert.NotContains(t, logs, "this is a debug log")
assert.NotContains(t, logs, "this is an info log")
}
}
}
Expand All @@ -465,34 +468,68 @@ func TestLog_custom(t *testing.T) {
if plugin, ok := plugin(t, manifest); ok {
defer plugin.Close()

var actual []LogEntry
var actual strings.Builder

var fmtLogMessage = func(level LogLevel, message string) string {
return fmt.Sprintf("%s: %s\n", level.String(), message)
}

plugin.SetLogger(func(level LogLevel, message string) {
actual = append(actual, LogEntry{message: message, level: level})
actual.WriteString(fmtLogMessage(level, message))
switch level {
case LogLevelDebug:
assert.Equal(t, level.String(), "DEBUG")
case LogLevelInfo:
assert.Equal(t, fmt.Sprintf("%s", level), "INFO")
assert.Equal(t, level.String(), "INFO")
case LogLevelWarn:
assert.Equal(t, fmt.Sprintf("%s", level), "WARN")
assert.Equal(t, level.String(), "WARN")
case LogLevelError:
assert.Equal(t, fmt.Sprintf("%s", level), "ERROR")
assert.Equal(t, level.String(), "ERROR")
case LogLevelTrace:
assert.Equal(t, fmt.Sprintf("%s", level), "TRACE")
assert.Equal(t, level.String(), "TRACE")
}
})

plugin.SetLogLevel(LogLevelInfo)
SetLogLevel(LogLevelTrace)

exit, _, err := plugin.Call("run_test", []byte{})

if assertCall(t, err, exit) {
expected := []LogEntry{
{message: "this is a trace log", level: LogLevelTrace},
{message: "this is a debug log", level: LogLevelDebug},
{message: "this is an info log", level: LogLevelInfo},
{message: "this is a warning log", level: LogLevelWarn},
{message: "this is an error log", level: LogLevelError},
{message: "this is a trace log", level: LogLevelTrace}}
}
actualLogs := actual.String()
for _, log := range expected {
assert.Contains(t, actualLogs, fmtLogMessage(log.level, log.message))
}
}

assert.Equal(t, expected, actual)
SetLogLevel(LogLevelWarn)
actual.Reset()

exit, _, err = plugin.Call("run_test", []byte{})

if assertCall(t, err, exit) {
expected := []LogEntry{
{message: "this is a warning log", level: LogLevelWarn},
{message: "this is an error log", level: LogLevelError},
}
expectedNot := []LogEntry{
{message: "this is a trace log", level: LogLevelTrace},
{message: "this is a debug log", level: LogLevelDebug},
{message: "this is an info log", level: LogLevelInfo},
}
actualLogs := actual.String()
for _, log := range expected {
assert.Contains(t, actualLogs, fmtLogMessage(log.level, log.message))
}
for _, log := range expectedNot {
assert.NotContains(t, actualLogs, fmtLogMessage(log.level, log.message))
}
}
}
}
Expand Down Expand Up @@ -699,7 +736,7 @@ func TestHelloHaskell(t *testing.T) {
if plugin, ok := plugin(t, manifest); ok {
defer plugin.Close()

plugin.SetLogLevel(LogLevelTrace)
SetLogLevel(LogLevelTrace)
plugin.Config["greeting"] = "Howdy"

exit, output, err := plugin.Call("testing", []byte("John"))
Expand Down Expand Up @@ -1068,16 +1105,6 @@ func BenchmarkReplace(b *testing.B) {
}
}

func generateRandomString(length int, seed int64) string {
rand.Seed(seed)
const charset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
result := make([]byte, length)
for i := range result {
result[i] = charset[rand.Intn(len(charset))]
}
return string(result)
}

func wasiPluginConfig() PluginConfig {
config := PluginConfig{
ModuleConfig: wazero.NewModuleConfig().WithSysWalltime(),
Expand Down
Loading

0 comments on commit 24b5f36

Please sign in to comment.