diff --git a/.changeset/sweet-sheep-arrive.md b/.changeset/sweet-sheep-arrive.md new file mode 100644 index 0000000000..2a0a9882c4 --- /dev/null +++ b/.changeset/sweet-sheep-arrive.md @@ -0,0 +1,13 @@ +--- +'xstate': minor +--- + +`fromPromise` now passes a signal into its creator function. + +```ts +const logic = fromPromise(({ signal }) => + fetch('https://api.example.com', { signal }) +); +``` + +This will be called whenever the state transitions before the promise is resolved. This is useful for cancelling the promise if the state changes. diff --git a/packages/core/src/actors/promise.ts b/packages/core/src/actors/promise.ts index 036ef9b1ae..a1fc89502f 100644 --- a/packages/core/src/actors/promise.ts +++ b/packages/core/src/actors/promise.ts @@ -3,6 +3,7 @@ import { AnyActorSystem } from '../system.ts'; import { ActorLogic, ActorRefFrom, + AnyActorRef, EventObject, NonReducibleUnknown, Snapshot @@ -71,6 +72,9 @@ export type PromiseActorRef = ActorRefFrom< * // } * ``` */ + +const controllerMap = new WeakMap(); + export function fromPromise( promiseCreator: ({ input, @@ -88,11 +92,12 @@ export function fromPromise( * The parent actor of the promise actor */ self: PromiseActorRef; + signal: AbortSignal; }) => PromiseLike ): PromiseActorLogic { const logic: PromiseActorLogic = { config: promiseCreator, - transition: (state, event) => { + transition: (state, event, scope) => { if (state.status !== 'active') { return state; } @@ -114,12 +119,14 @@ export function fromPromise( error: (event as any).data, input: undefined }; - case XSTATE_STOP: + case XSTATE_STOP: { + controllerMap.get(scope.self)?.abort(); return { ...state, status: 'stopped', input: undefined }; + } default: return state; } @@ -130,9 +137,15 @@ export function fromPromise( if (state.status !== 'active') { return; } - + const controller = new AbortController(); + controllerMap.set(self, controller); const resolvedPromise = Promise.resolve( - promiseCreator({ input: state.input!, system, self }) + promiseCreator({ + input: state.input!, + system, + self, + signal: controller.signal + }) ); resolvedPromise.then( @@ -140,6 +153,7 @@ export function fromPromise( if (self.getSnapshot().status !== 'active') { return; } + controllerMap.delete(self); system._relay(self, self, { type: XSTATE_PROMISE_RESOLVE, data: response @@ -149,6 +163,7 @@ export function fromPromise( if (self.getSnapshot().status !== 'active') { return; } + controllerMap.delete(self); system._relay(self, self, { type: XSTATE_PROMISE_REJECT, data: errorData diff --git a/packages/core/test/actorLogic.test.ts b/packages/core/test/actorLogic.test.ts index 9cba36d698..7f37a7f145 100644 --- a/packages/core/test/actorLogic.test.ts +++ b/packages/core/test/actorLogic.test.ts @@ -232,6 +232,238 @@ describe('promise logic (fromPromise)', () => { createActor(promiseLogic).start(); }); + + it('should abort when stopping', async () => { + const deferred = withResolvers(); + const fn = jest.fn(); + const promiseLogic = fromPromise((ctx) => { + return new Promise((res) => { + ctx.signal.addEventListener('abort', fn); + }); + }); + + const actor = createActor(promiseLogic).start(); + + actor.stop(); + + deferred.resolve(42); + await deferred.promise; + expect(fn).toHaveBeenCalled(); + }); + + it('should not abort when stopped if promise is resolved/rejected', async () => { + const resolvedDeferred = withResolvers(); + const resolvedSignalListener = jest.fn(); + const resolvedPromiseLogic = fromPromise((ctx) => { + ctx.signal.addEventListener('abort', resolvedSignalListener); + return resolvedDeferred.promise; + }); + + const rejectedDeferred = withResolvers(); + const rejectedSignalListener = jest.fn(); + const rejectedPromiseLogic = fromPromise((ctx) => { + ctx.signal.addEventListener('abort', rejectedSignalListener); + return rejectedDeferred.promise.catch(() => {}); + }); + + const actor = createActor(resolvedPromiseLogic).start(); + resolvedDeferred.resolve(42); + await waitFor(actor, (s) => s.status === 'done'); + actor.stop(); + expect(resolvedSignalListener).not.toHaveBeenCalled(); + + const actor2 = createActor(rejectedPromiseLogic).start(); + + rejectedDeferred.reject(50); + await rejectedDeferred.promise.catch(() => {}); + await waitFor(actor2, (s) => s.status === 'done'); + actor2.stop(); + expect(rejectedSignalListener).not.toHaveBeenCalled(); + }); + + it('should not reuse the same signal for different actors with same logic', async () => { + let deferredMap: Map> = new Map(); + let signalListenerMap: Map = new Map(); + const p = fromPromise(({ self, signal }) => { + const deferred = withResolvers(); + const signalListener = jest.fn(); + deferredMap.set(self.id, deferred); + signalListenerMap.set(self.id, signalListener); + signal.addEventListener('abort', signalListener); + return deferred.promise; + }); + const machine = createMachine({ + type: 'parallel', + states: { + p1: { + initial: 'running', + states: { + running: { + invoke: { + src: p, + id: 'p1' + }, + on: { + CANCEL_1: 'canceled' + } + }, + canceled: {} + } + }, + p2: { + initial: 'running', + states: { + running: { + invoke: { + src: p, + id: 'p2', + onDone: 'done' + } + }, + done: {} + } + } + } + }); + const actor = createActor(machine).start(); + + const p1Deferred = deferredMap.get('p1')!; + const p2Deferred = deferredMap.get('p2')!; + + actor.send({ type: 'CANCEL_1' }); + p1Deferred.resolve(42); + p2Deferred.resolve(42); + await Promise.all([ + waitFor(actor, (s) => s.matches('p1.canceled')), + waitFor(actor, (s) => s.matches('p2.done')) + ]); + expect(signalListenerMap.get('p1')).toHaveBeenCalled(); + expect(signalListenerMap.get('p2')).not.toHaveBeenCalled(); + }); + + it('should not reuse the same signal for different actors with same logic and id', async () => { + let deferredList: PromiseWithResolvers[] = []; + let signalListenerList: jest.Mock[] = []; + const p = fromPromise(({ signal }) => { + const deferred = withResolvers(); + const fn = jest.fn(); + deferredList.push(deferred); + signalListenerList.push(fn); + signal.addEventListener('abort', fn); + return deferred.promise; + }); + const machine = createMachine({ + type: 'parallel', + states: { + p1: { + initial: 'running', + states: { + running: { + invoke: { + src: p, + id: 'p' + }, + on: { + CANCEL_1: 'canceled' + } + }, + canceled: {} + } + }, + p2: { + initial: 'running', + states: { + running: { + invoke: { + src: p, + id: 'p', + onDone: 'done' + } + }, + done: {} + } + } + } + }); + const actor = createActor(machine).start(); + + const p1Deferred = deferredList[0]; + const p2Deferred = deferredList[1]; + const p1Fn = signalListenerList[0]; + const p2Fn = signalListenerList[1]; + + actor.send({ type: 'CANCEL_1' }); + p1Deferred.resolve(42); + p2Deferred.resolve(42); + + await Promise.all([ + waitFor(actor, (s) => s.matches('p1.canceled')), + waitFor(actor, (s) => s.matches('p2.done')) + ]); + + expect(p1Fn).toHaveBeenCalled(); + expect(p2Fn).not.toHaveBeenCalled(); + }); + + it('should not reuse the same signal for the same actor when restarted', async () => { + let deferredList: PromiseWithResolvers[] = []; + let signalListenerList: jest.Mock[] = []; + const p = fromPromise(({ signal }) => { + const deferred = withResolvers(); + const fn = jest.fn(); + deferredList.push(deferred); + signalListenerList.push(fn); + signal.addEventListener('abort', fn); + return deferred.promise; + }); + const machine = createMachine({ + initial: 'running', + states: { + running: { + invoke: { + src: p, + id: 'p', + onDone: 'done' + }, + on: { + cancel: 'canceled' + } + }, + done: { + on: { + restart: 'running' + } + }, + canceled: { + on: { + restart: 'running' + } + } + } + }); + const actor = createActor(machine).start(); + + // resolve the first promise and no canceling + await waitFor(actor, (s) => s.matches('running')); + const deferred1 = deferredList[0]; + const fn1 = signalListenerList[0]; + deferred1.resolve(42); + await waitFor(actor, (s) => s.matches('done')); + expect(fn1).not.toHaveBeenCalled(); + + actor.send({ type: 'restart' }); + + // cancel while running + await waitFor(actor, (s) => s.matches('running')); + actor.send({ type: 'cancel' }); + await waitFor(actor, (s) => s.matches('canceled')); + + const deferred2 = deferredList[1]; + deferred2.resolve(42); + await deferred2.promise; + const fn2 = signalListenerList[1]; + expect(fn2).toHaveBeenCalled(); + }); }); describe('transition function logic (fromTransition)', () => { @@ -1032,3 +1264,19 @@ describe('composable actor logic', () => { ); }); }); + +function withResolvers(): PromiseWithResolvers { + let resolve: (value: T | PromiseLike) => void; + let reject: (reason: any) => void; + + const promise = new Promise((res, rej) => { + resolve = res; + reject = rej; + }); + + return { + resolve: resolve!, + reject: reject!, + promise + }; +} diff --git a/tsconfig.json b/tsconfig.json index e251b6d7a5..bc744106d1 100644 --- a/tsconfig.json +++ b/tsconfig.json @@ -9,7 +9,7 @@ "skipLibCheck": true, "resolveJsonModule": true, "jsx": "react-jsx", - "lib": ["es2019", "dom"], + "lib": ["es2019", "ESNext.Promise", "dom"], "strict": true, "stripInternal": true },