Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Correctly use context in plugin and provide alternative _WithContext methods #62

Merged
merged 7 commits into from
Mar 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 34 additions & 16 deletions extism.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ type Runtime struct {
Wazero wazero.Runtime
Extism api.Module
Env api.Module
ctx context.Context
hasWasi bool
}

Expand Down Expand Up @@ -302,7 +301,12 @@ func (m *Manifest) UnmarshalJSON(data []byte) error {

// Close closes the plugin by freeing the underlying resources.
func (p *Plugin) Close() error {
return p.Runtime.Wazero.Close(p.Runtime.ctx)
return p.CloseWithContext(context.Background())
}

// CloseWithContext closes the plugin by freeing the underlying resources.
func (p *Plugin) CloseWithContext(ctx context.Context) error {
return p.Runtime.Wazero.Close(ctx)
}

// NewPlugin creates a new Extism plugin with the given manifest, configuration, and host functions.
Expand Down Expand Up @@ -351,17 +355,16 @@ func NewPlugin(
Wazero: rt,
Extism: extism,
Env: env,
ctx: ctx,
}

if config.EnableWasi {
wasi_snapshot_preview1.MustInstantiate(c.ctx, c.Wazero)
wasi_snapshot_preview1.MustInstantiate(ctx, c.Wazero)

c.hasWasi = true
}

for name, funcs := range hostModules {
_, err := buildHostModule(c.ctx, c.Wazero, name, funcs)
_, err := buildHostModule(ctx, c.Wazero, name, funcs)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -429,7 +432,7 @@ func NewPlugin(
}
}

m, err := c.Wazero.InstantiateWithConfig(c.ctx, data.Data, moduleConfig.WithName(data.Name))
m, err := c.Wazero.InstantiateWithConfig(ctx, data.Data, moduleConfig.WithName(data.Name))
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -470,7 +473,7 @@ func NewPlugin(
logLevel: logLevel,
}

p.guestRuntime = detectGuestRuntime(p)
p.guestRuntime = detectGuestRuntime(ctx, p)
return p, nil
}

Expand All @@ -482,29 +485,39 @@ func NewPlugin(

// SetInput sets the input data for the plugin to be used in the next WebAssembly function call.
func (plugin *Plugin) SetInput(data []byte) (uint64, error) {
_, err := plugin.Runtime.Extism.ExportedFunction("reset").Call(plugin.Runtime.ctx)
return plugin.SetInputWithContext(context.Background(), data)
}

// SetInputWithContext sets the input data for the plugin to be used in the next WebAssembly function call.
func (plugin *Plugin) SetInputWithContext(ctx context.Context, data []byte) (uint64, error) {
_, err := plugin.Runtime.Extism.ExportedFunction("reset").Call(ctx)
if err != nil {
fmt.Println(err)
return 0, errors.New("reset")
}

ptr, err := plugin.Runtime.Extism.ExportedFunction("alloc").Call(plugin.Runtime.ctx, uint64(len(data)))
ptr, err := plugin.Runtime.Extism.ExportedFunction("alloc").Call(ctx, uint64(len(data)))
if err != nil {
return 0, err
}
plugin.Memory().Write(uint32(ptr[0]), data)
plugin.Runtime.Extism.ExportedFunction("input_set").Call(plugin.Runtime.ctx, ptr[0], uint64(len(data)))
plugin.Runtime.Extism.ExportedFunction("input_set").Call(ctx, ptr[0], uint64(len(data)))
return ptr[0], nil
}

// GetOutput retrieves the output data from the last WebAssembly function call.
func (plugin *Plugin) GetOutput() ([]byte, error) {
outputOffs, err := plugin.Runtime.Extism.ExportedFunction("output_offset").Call(plugin.Runtime.ctx)
return plugin.GetOutputWithContext(context.Background())
}

// GetOutputWithContext retrieves the output data from the last WebAssembly function call.
func (plugin *Plugin) GetOutputWithContext(ctx context.Context) ([]byte, error) {
outputOffs, err := plugin.Runtime.Extism.ExportedFunction("output_offset").Call(ctx)
if err != nil {
return []byte{}, err
}

outputLen, err := plugin.Runtime.Extism.ExportedFunction("output_length").Call(plugin.Runtime.ctx)
outputLen, err := plugin.Runtime.Extism.ExportedFunction("output_length").Call(ctx)
if err != nil {
return []byte{}, err
}
Expand All @@ -524,7 +537,12 @@ func (plugin *Plugin) Memory() api.Memory {

// GetError retrieves the error message from the last WebAssembly function call, if any.
func (plugin *Plugin) GetError() string {
errOffs, err := plugin.Runtime.Extism.ExportedFunction("error_get").Call(plugin.Runtime.ctx)
return plugin.GetErrorWithContext(context.Background())
}

// GetErrorWithContext retrieves the error message from the last WebAssembly function call.
func (plugin *Plugin) GetErrorWithContext(ctx context.Context) string {
errOffs, err := plugin.Runtime.Extism.ExportedFunction("error_get").Call(ctx)
if err != nil {
return ""
}
Expand All @@ -533,7 +551,7 @@ func (plugin *Plugin) GetError() string {
return ""
}

errLen, err := plugin.Runtime.Extism.ExportedFunction("length").Call(plugin.Runtime.ctx, errOffs[0])
errLen, err := plugin.Runtime.Extism.ExportedFunction("length").Call(ctx, errOffs[0])
if err != nil {
return ""
}
Expand All @@ -549,7 +567,7 @@ func (plugin *Plugin) FunctionExists(name string) bool {

// Call a function by name with the given input, returning the output
func (plugin *Plugin) Call(name string, data []byte) (uint32, []byte, error) {
return plugin.CallWithContext(plugin.Runtime.ctx, name, data)
return plugin.CallWithContext(context.Background(), name, data)
}

// Call a function by name with the given input and context, returning the output
Expand Down Expand Up @@ -579,7 +597,7 @@ func (plugin *Plugin) CallWithContext(ctx context.Context, name string, data []b

var isStart = name == "_start"
if plugin.guestRuntime.init != nil && !isStart && !plugin.guestRuntime.initialized {
err := plugin.guestRuntime.init()
err := plugin.guestRuntime.init(ctx)
if err != nil {
return 1, []byte{}, errors.New(fmt.Sprintf("failed to initialize runtime: %v", err))
}
Expand Down
71 changes: 69 additions & 2 deletions extism_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ import (

"github.com/stretchr/testify/assert"
"github.com/tetratelabs/wazero"
"github.com/tetratelabs/wazero/experimental"
"github.com/tetratelabs/wazero/experimental/logging"
"github.com/tetratelabs/wazero/sys"
)

Expand Down Expand Up @@ -518,7 +520,7 @@ func TestCancel(t *testing.T) {
manifest := manifest("sleep.wasm")
manifest.Config["duration"] = "3" // sleep for 3 seconds

ctx, cancel := context.WithCancel(context.Background())
ctx := context.Background()
config := PluginConfig{
ModuleConfig: wazero.NewModuleConfig().WithSysWalltime(),
EnableWasi: true,
Expand All @@ -533,12 +535,13 @@ func TestCancel(t *testing.T) {

defer plugin.Close()

ctx, cancel := context.WithCancel(context.Background())
go func() {
time.Sleep(100 * time.Millisecond)
cancel()
}()

exit, _, err := plugin.Call("run_test", []byte{})
exit, _, err := plugin.CallWithContext(ctx, "run_test", []byte{})

assert.Equal(t, sys.ExitCodeContextCanceled, exit, "Exit code must be `sys.ExitCodeContextCanceled`")
assert.Equal(t, "module closed with context canceled", err.Error())
Expand Down Expand Up @@ -734,6 +737,70 @@ func TestInputOffset(t *testing.T) {
}
}

// make sure cancelling the context given to NewPlugin doesn't affect plugin calls
func TestContextCancel(t *testing.T) {
manifest := manifest("sleep.wasm")
manifest.Config["duration"] = "0" // sleep for 0 seconds

ctx, cancel := context.WithCancel(context.Background())
config := PluginConfig{
ModuleConfig: wazero.NewModuleConfig().WithSysWalltime(),
EnableWasi: true,
RuntimeConfig: wazero.NewRuntimeConfig().WithCloseOnContextDone(true),
}

plugin, err := NewPlugin(ctx, manifest, config, []HostFunction{})

if err != nil {
t.Errorf("Could not create plugin: %v", err)
}

defer plugin.Close()
cancel() // cancel the parent context

exit, out, err := plugin.CallWithContext(context.Background(), "run_test", []byte{})

if assertCall(t, err, exit) {
assert.Equal(t, "slept for 0 seconds", string(out))
}
}

// make sure we can still turn on experimental wazero features
func TestEnableExperimentalFeature(t *testing.T) {
var buf bytes.Buffer

// Set context to one that has an experimental listener
ctx := context.WithValue(context.Background(), experimental.FunctionListenerFactoryKey{}, logging.NewLoggingListenerFactory(&buf))

manifest := manifest("sleep.wasm")
manifest.Config["duration"] = "0" // sleep for 0 seconds

config := PluginConfig{
ModuleConfig: wazero.NewModuleConfig().WithSysWalltime(),
EnableWasi: true,
RuntimeConfig: wazero.NewRuntimeConfig().WithCloseOnContextDone(true),
}

plugin, err := NewPlugin(ctx, manifest, config, []HostFunction{})

if err != nil {
t.Errorf("Could not create plugin: %v", err)
}

defer plugin.Close()

var buf2 bytes.Buffer
ctx = context.WithValue(context.Background(), experimental.FunctionListenerFactoryKey{}, logging.NewLoggingListenerFactory(&buf2))
exit, out, err := plugin.CallWithContext(ctx, "run_test", []byte{})

if assertCall(t, err, exit) {
assert.Equal(t, "slept for 0 seconds", string(out))

assert.NotEmpty(t, buf.String())
assert.Empty(t, buf2.String())
}
}

func BenchmarkInitialize(b *testing.B) {
ctx := context.Background()
cache := wazero.NewCompilationCache()
Expand Down
21 changes: 18 additions & 3 deletions host.go
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,12 @@ func (p *CurrentPlugin) Memory() api.Memory {

// Alloc a new memory block of the given length, returning its offset
func (p *CurrentPlugin) Alloc(n uint64) (uint64, error) {
out, err := p.plugin.Runtime.Extism.ExportedFunction("alloc").Call(p.plugin.Runtime.ctx, uint64(n))
return p.AllocWithContext(context.Background(), n)
}

// Alloc a new memory block of the given length, returning its offset
func (p *CurrentPlugin) AllocWithContext(ctx context.Context, n uint64) (uint64, error) {
out, err := p.plugin.Runtime.Extism.ExportedFunction("alloc").Call(ctx, uint64(n))
if err != nil {
return 0, err
} else if len(out) != 1 {
Expand All @@ -135,7 +140,12 @@ func (p *CurrentPlugin) Alloc(n uint64) (uint64, error) {

// Free the memory block specified by the given offset
func (p *CurrentPlugin) Free(offset uint64) error {
_, err := p.plugin.Runtime.Extism.ExportedFunction("free").Call(p.plugin.Runtime.ctx, uint64(offset))
return p.FreeWithContext(context.Background(), offset)
}

// Free the memory block specified by the given offset
func (p *CurrentPlugin) FreeWithContext(ctx context.Context, offset uint64) error {
_, err := p.plugin.Runtime.Extism.ExportedFunction("free").Call(ctx, uint64(offset))
if err != nil {
return err
}
Expand All @@ -145,7 +155,12 @@ func (p *CurrentPlugin) Free(offset uint64) error {

// Length returns the number of bytes allocated at the specified offset
func (p *CurrentPlugin) Length(offs uint64) (uint64, error) {
out, err := p.plugin.Runtime.Extism.ExportedFunction("length").Call(p.plugin.Runtime.ctx, uint64(offs))
return p.LengthWithContext(context.Background(), offs)
}

// Length returns the number of bytes allocated at the specified offset
func (p *CurrentPlugin) LengthWithContext(ctx context.Context, offs uint64) (uint64, error) {
out, err := p.plugin.Runtime.Extism.ExportedFunction("length").Call(ctx, uint64(offs))
if err != nil {
return 0, err
} else if len(out) != 1 {
Expand Down
Loading
Loading