Skip to content

Commit

Permalink
fix(std:runtime_funcs): accumulator
Browse files Browse the repository at this point in the history
  • Loading branch information
emil14 committed Sep 21, 2024
1 parent 4c8d743 commit 536ec23
Show file tree
Hide file tree
Showing 8 changed files with 181 additions and 25 deletions.
4 changes: 2 additions & 2 deletions examples/reduce_list/e2e_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ func Test(t *testing.T) {
require.NoError(t, err)
defer os.Chdir(wd)

for i := 0; i < 100; i++ {
cmd := exec.Command("neva", "run", "filter_list")
for i := 0; i < 1; i++ {
cmd := exec.Command("neva", "run", "reduce_list")

// Set a timeout for the command
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
Expand Down
10 changes: 6 additions & 4 deletions examples/reduce_list/main.neva
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
const lst list<int> = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]

flow Main(start) (stop) {
Iter, Reduce{AddReducer}, Println
Iter<int>
Reduce<int, int>{AddReducer<int>}
Println
---
:start -> (
$lst -> iter -> reduce -> println -> :stop
)
:start -> ($lst -> iter -> reduce:data)
0 -> reduce:init
reduce -> println -> :stop
}
92 changes: 92 additions & 0 deletions internal/runtime/funcs/accumulator.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
package funcs

import (
"context"
"sync"

"github.com/nevalang/neva/internal/runtime"
)

type accumulator struct{}

func (a accumulator) Create(io runtime.IO, _ runtime.Msg) (func(ctx context.Context), error) {
initIn, err := io.In.Single("init")
if err != nil {
return nil, err
}

updIn, err := io.In.Single("upd")
if err != nil {
return nil, err
}

lastIn, err := io.In.Single("last")
if err != nil {
return nil, err
}

curOut, err := io.Out.Single("cur")
if err != nil {
return nil, err
}

resOut, err := io.Out.Single("res")
if err != nil {
return nil, err
}

return func(ctx context.Context) {
for {
var (
acc runtime.Msg
last = false
)

initMsg, initOk := initIn.Receive(ctx)
if !initOk {
return
}

if !curOut.Send(ctx, initMsg) {
return
}

acc = initMsg

for !last {
var dataMsg, lastMsg runtime.Msg
var dataOk, lastOk bool

var wg sync.WaitGroup
wg.Add(2)

go func() {
defer wg.Done()
dataMsg, dataOk = updIn.Receive(ctx)
}()

go func() {
defer wg.Done()
lastMsg, lastOk = lastIn.Receive(ctx)
}()

wg.Wait()

if !dataOk || !lastOk {
return
}

if !curOut.Send(ctx, dataMsg) {
return
}

acc = dataMsg
last = lastMsg.Bool()
}

if !resOut.Send(ctx, acc) {
return
}
}
}, nil
}
25 changes: 19 additions & 6 deletions internal/runtime/funcs/cond.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package funcs

import (
"context"
"sync"

"github.com/nevalang/neva/internal/runtime"
)
Expand Down Expand Up @@ -31,13 +32,25 @@ func (c cond) Create(io runtime.IO, _ runtime.Msg) (func(ctx context.Context), e

return func(ctx context.Context) {
for {
dataMsg, ok := dataIn.Receive(ctx)
if !ok {
return
}
var dataMsg, ifMsg runtime.Msg
var dataOk, ifOk bool

var wg sync.WaitGroup
wg.Add(2)

go func() {
defer wg.Done()
dataMsg, dataOk = dataIn.Receive(ctx)
}()

go func() {
defer wg.Done()
ifMsg, ifOk = ifIn.Receive(ctx)
}()

wg.Wait()

ifMsg, ok := ifIn.Receive(ctx)
if !ok {
if !dataOk || !ifOk {
return
}

Expand Down
31 changes: 22 additions & 9 deletions internal/runtime/funcs/int_add_reducer.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package funcs

import (
"context"
"sync"

"github.com/nevalang/neva/internal/runtime"
)
Expand All @@ -12,12 +13,12 @@ func (intAddReducer) Create(
io runtime.IO,
_ runtime.Msg,
) (func(ctx context.Context), error) {
firstIn, err := io.In.Single("first")
accIn, err := io.In.Single("acc")
if err != nil {
return nil, err
}

secondIn, err := io.In.Single("second")
elIn, err := io.In.Single("el")
if err != nil {
return nil, err
}
Expand All @@ -29,17 +30,29 @@ func (intAddReducer) Create(

return func(ctx context.Context) {
for {
firstMsg, ok := firstIn.Receive(ctx)
if !ok {
return
}
var accMsg, elMsg runtime.Msg
var accOk, elOk bool

var wg sync.WaitGroup
wg.Add(2)

go func() {
defer wg.Done()
accMsg, accOk = accIn.Receive(ctx)
}()

go func() {
defer wg.Done()
elMsg, elOk = elIn.Receive(ctx)
}()

wg.Wait()

secondMsg, ok := secondIn.Receive(ctx)
if !ok {
if !accOk || !elOk {
return
}

resMsg := runtime.NewIntMsg(firstMsg.Int() + secondMsg.Int())
resMsg := runtime.NewIntMsg(elMsg.Int() + accMsg.Int())
if !resOut.Send(ctx, resMsg) {
return
}
Expand Down
3 changes: 3 additions & 0 deletions internal/runtime/funcs/registry.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,5 +101,8 @@ func CreatorRegistry() map[string]runtime.FuncCreator {

// sync
"wait_group": waitGroup{},

// other
"accumulator": accumulator{},
}
}
1 change: 1 addition & 0 deletions std/builtin/math.neva
Original file line number Diff line number Diff line change
Expand Up @@ -18,5 +18,6 @@ pub flow Mod(num int, den int) (res int, err error)

// === Reducer Interface Experiment ===

// IDEA: res->acc?
#extern(int int_add_reducer, float float_add_reducer, string string_add_reducer)
pub flow AddReducer<T int | float | string>(acc T, el T) (res T)
40 changes: 36 additions & 4 deletions std/builtin/streams.neva
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,11 @@ pub flow Range(from int, to int) (data stream<int>)
// StreamPort iterates over all array-inport's slots in order
// and produces a stream of messages.
#extern(array_port_to_stream)
pub flow StreamPort<T>([port] T) (seq stream<T>)
pub flow StreamPort<T>([port] T) (data stream<T>) // TODO data->res

// IPortReducer reduces a stream of messages to one single message.
// It's expected to send a result message after every processed stream.
pub interface IPortReducer<T>(seq stream<T>) (res T)
pub interface IPortReducer<T>(data stream<T>) (res T)

// ReducePort reduces messages from multiple connections to a single message.
// It iterates over all array-inport's slots in order and streams every message
Expand Down Expand Up @@ -76,7 +76,6 @@ pub interface IFilterHandler<T>(data T) (res bool)
// It's possible to chain multiple Filters together.
pub flow Filter<T>(data stream<T>) (res stream<T>) {
Cond<stream<T>>
Del
FanOut<stream<T>>
handler IFilterHandler<T>

Expand All @@ -88,9 +87,42 @@ pub flow Filter<T>(data stream<T>) (res stream<T>) {
fanOut[1].data -> handler -> cond:if

cond:then -> :res
cond:else -> del
}

interface IReduceHandler<T, Y>(acc T, el T) (res Y) // TODO res->acc?

// Reduce applies a reduction function to a stream of values, accumulating the result.
// It takes an initial value and a stream of data, and produces a single result.
pub flow Reduce<T, Y>(data stream<T>, init Y) (res Y) {
handler IReduceHandler<T, Y>
fanOut FanOut<stream<T>>
acc Accumulator<Y>

---

:init -> acc:init

:data -> fanOut
fanOut[0].data -> handler:el
fanOut[1].last -> acc:last

acc:cur -> handler:acc
handler:res -> acc:upd

acc:res -> :res
}

// IDEA:
// pub type AccumulatorUpdate<T> strcut {
// data T
// last bool
// }

// Accumulator maintains the current state of the reduction.
// It updates its value with each new input and outputs the final result when last is true.
#extern(accumulator)
pub flow Accumulator<T>(init T, upd T, last bool) (cur T, res T)

// --- For ---

// IForHandler is a dependency for For flow.
Expand Down

0 comments on commit 536ec23

Please sign in to comment.