Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Correctly use context in plugin and provide alternative _WithContext methods #62

Merged
merged 7 commits into from
Mar 12, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 33 additions & 15 deletions extism.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ type Runtime struct {
Wazero wazero.Runtime
Extism api.Module
Env api.Module
ctx context.Context
hasWasi bool
}

Expand Down Expand Up @@ -302,7 +301,12 @@ func (m *Manifest) UnmarshalJSON(data []byte) error {

// Close closes the plugin by freeing the underlying resources.
func (p *Plugin) Close() error {
return p.Runtime.Wazero.Close(p.Runtime.ctx)
return p.CloseWithContext(context.Background())
}

// Close closes the plugin by freeing the underlying resources.
Marton6 marked this conversation as resolved.
Show resolved Hide resolved
func (p *Plugin) CloseWithContext(ctx context.Context) error {
return p.Runtime.Wazero.Close(ctx)
}

// NewPlugin creates a new Extism plugin with the given manifest, configuration, and host functions.
Expand Down Expand Up @@ -351,17 +355,16 @@ func NewPlugin(
Wazero: rt,
Extism: extism,
Env: env,
ctx: ctx,
}

if config.EnableWasi {
wasi_snapshot_preview1.MustInstantiate(c.ctx, c.Wazero)
wasi_snapshot_preview1.MustInstantiate(ctx, c.Wazero)

c.hasWasi = true
}

for name, funcs := range hostModules {
_, err := buildHostModule(c.ctx, c.Wazero, name, funcs)
_, err := buildHostModule(ctx, c.Wazero, name, funcs)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -429,7 +432,7 @@ func NewPlugin(
}
}

m, err := c.Wazero.InstantiateWithConfig(c.ctx, data.Data, moduleConfig.WithName(data.Name))
m, err := c.Wazero.InstantiateWithConfig(ctx, data.Data, moduleConfig.WithName(data.Name))
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -470,7 +473,7 @@ func NewPlugin(
logLevel: logLevel,
}

p.guestRuntime = detectGuestRuntime(p)
p.guestRuntime = detectGuestRuntime(ctx, p)
return p, nil
}

Expand All @@ -482,29 +485,39 @@ func NewPlugin(

// SetInput sets the input data for the plugin to be used in the next WebAssembly function call.
func (plugin *Plugin) SetInput(data []byte) (uint64, error) {
_, err := plugin.Runtime.Extism.ExportedFunction("reset").Call(plugin.Runtime.ctx)
return plugin.SetInputWithContext(context.Background(), data)
}

// SetInput sets the input data for the plugin to be used in the next WebAssembly function call.
Marton6 marked this conversation as resolved.
Show resolved Hide resolved
func (plugin *Plugin) SetInputWithContext(ctx context.Context, data []byte) (uint64, error) {
_, err := plugin.Runtime.Extism.ExportedFunction("reset").Call(ctx)
if err != nil {
fmt.Println(err)
return 0, errors.New("reset")
}

ptr, err := plugin.Runtime.Extism.ExportedFunction("alloc").Call(plugin.Runtime.ctx, uint64(len(data)))
ptr, err := plugin.Runtime.Extism.ExportedFunction("alloc").Call(ctx, uint64(len(data)))
if err != nil {
return 0, err
}
plugin.Memory().Write(uint32(ptr[0]), data)
plugin.Runtime.Extism.ExportedFunction("input_set").Call(plugin.Runtime.ctx, ptr[0], uint64(len(data)))
plugin.Runtime.Extism.ExportedFunction("input_set").Call(ctx, ptr[0], uint64(len(data)))
return ptr[0], nil
}

// GetOutput retrieves the output data from the last WebAssembly function call.
func (plugin *Plugin) GetOutput() ([]byte, error) {
outputOffs, err := plugin.Runtime.Extism.ExportedFunction("output_offset").Call(plugin.Runtime.ctx)
return plugin.GetOutputWithContext(context.Background())
}

// GetOutput retrieves the output data from the last WebAssembly function call.
Marton6 marked this conversation as resolved.
Show resolved Hide resolved
func (plugin *Plugin) GetOutputWithContext(ctx context.Context) ([]byte, error) {
outputOffs, err := plugin.Runtime.Extism.ExportedFunction("output_offset").Call(ctx)
if err != nil {
return []byte{}, err
}

outputLen, err := plugin.Runtime.Extism.ExportedFunction("output_length").Call(plugin.Runtime.ctx)
outputLen, err := plugin.Runtime.Extism.ExportedFunction("output_length").Call(ctx)
if err != nil {
return []byte{}, err
}
Expand All @@ -524,7 +537,12 @@ func (plugin *Plugin) Memory() api.Memory {

// GetError retrieves the error message from the last WebAssembly function call, if any.
func (plugin *Plugin) GetError() string {
errOffs, err := plugin.Runtime.Extism.ExportedFunction("error_get").Call(plugin.Runtime.ctx)
return plugin.GetErrorWithContext(context.Background())
}

// GetError retrieves the error message from the last WebAssembly function call, if any.
Marton6 marked this conversation as resolved.
Show resolved Hide resolved
func (plugin *Plugin) GetErrorWithContext(ctx context.Context) string {
errOffs, err := plugin.Runtime.Extism.ExportedFunction("error_get").Call(ctx)
if err != nil {
return ""
}
Expand All @@ -533,7 +551,7 @@ func (plugin *Plugin) GetError() string {
return ""
}

errLen, err := plugin.Runtime.Extism.ExportedFunction("length").Call(plugin.Runtime.ctx, errOffs[0])
errLen, err := plugin.Runtime.Extism.ExportedFunction("length").Call(ctx, errOffs[0])
if err != nil {
return ""
}
Expand All @@ -549,7 +567,7 @@ func (plugin *Plugin) FunctionExists(name string) bool {

// Call a function by name with the given input, returning the output
func (plugin *Plugin) Call(name string, data []byte) (uint32, []byte, error) {
return plugin.CallWithContext(plugin.Runtime.ctx, name, data)
return plugin.CallWithContext(context.Background(), name, data)
}

// Call a function by name with the given input and context, returning the output
Expand Down
5 changes: 3 additions & 2 deletions extism_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -518,7 +518,7 @@ func TestCancel(t *testing.T) {
manifest := manifest("sleep.wasm")
manifest.Config["duration"] = "3" // sleep for 3 seconds

ctx, cancel := context.WithCancel(context.Background())
ctx := context.Background()
config := PluginConfig{
ModuleConfig: wazero.NewModuleConfig().WithSysWalltime(),
EnableWasi: true,
Expand All @@ -533,12 +533,13 @@ func TestCancel(t *testing.T) {

defer plugin.Close()

ctx, cancel := context.WithCancel(context.Background())
go func() {
time.Sleep(100 * time.Millisecond)
cancel()
}()

exit, _, err := plugin.Call("run_test", []byte{})
exit, _, err := plugin.CallWithContext(ctx, "run_test", []byte{})

assert.Equal(t, sys.ExitCodeContextCanceled, exit, "Exit code must be `sys.ExitCodeContextCanceled`")
assert.Equal(t, "module closed with context canceled", err.Error())
Expand Down
21 changes: 18 additions & 3 deletions host.go
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,12 @@ func (p *CurrentPlugin) Memory() api.Memory {

// Alloc a new memory block of the given length, returning its offset
func (p *CurrentPlugin) Alloc(n uint64) (uint64, error) {
out, err := p.plugin.Runtime.Extism.ExportedFunction("alloc").Call(p.plugin.Runtime.ctx, uint64(n))
return p.AllocWithContext(context.Background(), n)
}

// Alloc a new memory block of the given length, returning its offset
func (p *CurrentPlugin) AllocWithContext(ctx context.Context, n uint64) (uint64, error) {
out, err := p.plugin.Runtime.Extism.ExportedFunction("alloc").Call(ctx, uint64(n))
if err != nil {
return 0, err
} else if len(out) != 1 {
Expand All @@ -135,7 +140,12 @@ func (p *CurrentPlugin) Alloc(n uint64) (uint64, error) {

// Free the memory block specified by the given offset
func (p *CurrentPlugin) Free(offset uint64) error {
_, err := p.plugin.Runtime.Extism.ExportedFunction("free").Call(p.plugin.Runtime.ctx, uint64(offset))
return p.FreeWithContext(context.Background(), offset)
}

// Free the memory block specified by the given offset
func (p *CurrentPlugin) FreeWithContext(ctx context.Context, offset uint64) error {
_, err := p.plugin.Runtime.Extism.ExportedFunction("free").Call(ctx, uint64(offset))
if err != nil {
return err
}
Expand All @@ -145,7 +155,12 @@ func (p *CurrentPlugin) Free(offset uint64) error {

// Length returns the number of bytes allocated at the specified offset
func (p *CurrentPlugin) Length(offs uint64) (uint64, error) {
out, err := p.plugin.Runtime.Extism.ExportedFunction("length").Call(p.plugin.Runtime.ctx, uint64(offs))
return p.LengthWithContext(context.Background(), offs)
}

// Length returns the number of bytes allocated at the specified offset
func (p *CurrentPlugin) LengthWithContext(ctx context.Context, offs uint64) (uint64, error) {
out, err := p.plugin.Runtime.Extism.ExportedFunction("length").Call(ctx, uint64(offs))
if err != nil {
return 0, err
} else if len(out) != 1 {
Expand Down
32 changes: 17 additions & 15 deletions runtime.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package extism

import (
"context"

"github.com/tetratelabs/wazero/api"
)

Expand All @@ -20,15 +22,15 @@ type guestRuntime struct {
initialized bool
}

func detectGuestRuntime(p *Plugin) guestRuntime {
func detectGuestRuntime(ctx context.Context, p *Plugin) guestRuntime {
m := p.Main

runtime, ok := haskellRuntime(p, m)
runtime, ok := haskellRuntime(ctx, p, m)
if ok {
return runtime
}

runtime, ok = wasiRuntime(p, m)
runtime, ok = wasiRuntime(ctx, p, m)
if ok {
return runtime
}
Expand All @@ -40,7 +42,7 @@ func detectGuestRuntime(p *Plugin) guestRuntime {
// Check for Haskell runtime initialization functions
// Initialize Haskell runtime if `hs_init` and `hs_exit` are present,
// by calling the `hs_init` export
func haskellRuntime(p *Plugin, m api.Module) (guestRuntime, bool) {
func haskellRuntime(ctx context.Context, p *Plugin, m api.Module) (guestRuntime, bool) {
initFunc := m.ExportedFunction("hs_init")
if initFunc == nil {
return guestRuntime{}, false
Expand All @@ -56,12 +58,12 @@ func haskellRuntime(p *Plugin, m api.Module) (guestRuntime, bool) {

init := func() error {
if reactorInit != nil {
_, err := reactorInit.Call(p.Runtime.ctx)
_, err := reactorInit.Call(ctx)
if err != nil {
p.Logf(LogLevelError, "Error running reactor _initialize: %s", err.Error())
}
}
_, err := initFunc.Call(p.Runtime.ctx, 0, 0)
_, err := initFunc.Call(ctx, 0, 0)
if err == nil {
p.Log(LogLevelDebug, "Initialized Haskell language runtime.")
}
Expand All @@ -74,24 +76,24 @@ func haskellRuntime(p *Plugin, m api.Module) (guestRuntime, bool) {
}

// Check for initialization functions defined by the WASI standard
func wasiRuntime(p *Plugin, m api.Module) (guestRuntime, bool) {
func wasiRuntime(ctx context.Context, p *Plugin, m api.Module) (guestRuntime, bool) {
if !p.Runtime.hasWasi {
return guestRuntime{}, false
}

// WASI supports two modules: Reactors and Commands
// we prioritize Reactors over Commands
// see: https://github.com/WebAssembly/WASI/blob/main/legacy/application-abi.md
if r, ok := reactorModule(m, p); ok {
if r, ok := reactorModule(ctx, m, p); ok {
return r, ok
}

return commandModule(m, p)
return commandModule(ctx, m, p)
}

// Check for `_initialize` this is used by WASI to initialize certain interfaces.
func reactorModule(m api.Module, p *Plugin) (guestRuntime, bool) {
init := findFunc(m, p, "_initialize")
func reactorModule(ctx context.Context, m api.Module, p *Plugin) (guestRuntime, bool) {
init := findFunc(ctx, m, p, "_initialize")
if init == nil {
return guestRuntime{}, false
}
Expand All @@ -104,8 +106,8 @@ func reactorModule(m api.Module, p *Plugin) (guestRuntime, bool) {

// Check for `__wasm__call_ctors`, this is used by WASI to
// initialize certain interfaces.
func commandModule(m api.Module, p *Plugin) (guestRuntime, bool) {
init := findFunc(m, p, "__wasm_call_ctors")
func commandModule(ctx context.Context, m api.Module, p *Plugin) (guestRuntime, bool) {
init := findFunc(ctx, m, p, "__wasm_call_ctors")
if init == nil {
return guestRuntime{}, false
}
Expand All @@ -116,7 +118,7 @@ func commandModule(m api.Module, p *Plugin) (guestRuntime, bool) {
return guestRuntime{runtimeType: Wasi, init: init}, true
}

func findFunc(m api.Module, p *Plugin, name string) func() error {
func findFunc(ctx context.Context, m api.Module, p *Plugin, name string) func() error {
initFunc := m.ExportedFunction(name)
if initFunc == nil {
return nil
Expand All @@ -130,7 +132,7 @@ func findFunc(m api.Module, p *Plugin, name string) func() error {

return func() error {
p.Logf(LogLevelDebug, "Calling %v", name)
_, err := initFunc.Call(p.Runtime.ctx)
_, err := initFunc.Call(ctx)
return err
Marton6 marked this conversation as resolved.
Show resolved Hide resolved
}
}
Expand Down
Loading