Skip to content

Commit

Permalink
Fix LoopbackClient issues (#6481)
Browse files Browse the repository at this point in the history
This PR fixes 2 bugs in the LoopbackClient:

1. If the server doesn't spin up immediately, getting the redirectUri
will fail. Addressed by moving the setInterval out of the
`listenForAuthCode` function and instead wrapping the `getRedirectUri`
with it
1. If the listener throws before the promise is awaited the rejection
becomes unhandled. Addressed by attaching the catch handler on promise
creation and storing the error for later.
  • Loading branch information
tnorling authored Sep 28, 2023
1 parent 781325b commit b69ed83
Show file tree
Hide file tree
Showing 4 changed files with 232 additions and 47 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
{
"type": "minor",
"comment": "Fix unhandled rejections & problems getting redirectUri in LoopbackClient #6481",
"packageName": "@azure/msal-node",
"email": "[email protected]",
"dependentChangeType": "patch"
}
82 changes: 69 additions & 13 deletions lib/msal-node/src/client/PublicClientApplication.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,11 @@
* Licensed under the MIT License.
*/

import { ApiId, Constants } from "../utils/Constants.js";
import {
ApiId,
Constants,
LOOPBACK_SERVER_CONSTANTS,
} from "../utils/Constants.js";
import {
AuthenticationResult,
CommonDeviceCodeRequest,
Expand All @@ -26,7 +30,7 @@ import { DeviceCodeRequest } from "../request/DeviceCodeRequest.js";
import { AuthorizationUrlRequest } from "../request/AuthorizationUrlRequest.js";
import { AuthorizationCodeRequest } from "../request/AuthorizationCodeRequest.js";
import { InteractiveRequest } from "../request/InteractiveRequest.js";
import { NodeAuthError } from "../error/NodeAuthError.js";
import { NodeAuthError, NodeAuthErrorMessage } from "../error/NodeAuthError.js";
import { LoopbackClient } from "../network/LoopbackClient.js";
import { SilentFlowRequest } from "../request/SilentFlowRequest.js";
import { SignOutRequest } from "../request/SignOutRequest.js";
Expand Down Expand Up @@ -167,13 +171,21 @@ export class PublicClientApplication
const loopbackClient: ILoopbackClient =
customLoopbackClient || new LoopbackClient();

let authCodeListener: Promise<ServerAuthorizationCodeResponse>;
let authCodeResponse: ServerAuthorizationCodeResponse = {};
let authCodeListenerError: AuthError | null = null;
try {
authCodeListener = loopbackClient.listenForAuthCode(
successTemplate,
errorTemplate
);
const redirectUri = loopbackClient.getRedirectUri();
const authCodeListener = loopbackClient
.listenForAuthCode(successTemplate, errorTemplate)
.then((response) => {
authCodeResponse = response;
})
.catch((e) => {
// Store the promise instead of throwing so we can control when its thrown
authCodeListenerError = e;
});

// Wait for server to be listening
const redirectUri = await this.waitForRedirectUri(loopbackClient);

const validRequest: AuthorizationUrlRequest = {
...remainingProperties,
Expand All @@ -187,9 +199,10 @@ export class PublicClientApplication

const authCodeUrl = await this.getAuthCodeUrl(validRequest);
await openBrowser(authCodeUrl);
const authCodeResponse = await authCodeListener.finally(() => {
loopbackClient.closeServer();
});
await authCodeListener;
if (authCodeListenerError) {
throw authCodeListenerError;
}

if (authCodeResponse.error) {
throw new ServerError(
Expand All @@ -209,9 +222,8 @@ export class PublicClientApplication
...validRequest,
};
return this.acquireTokenByCode(tokenRequest);
} catch (e) {
} finally {
loopbackClient.closeServer();
throw e;
}
}

Expand Down Expand Up @@ -280,4 +292,48 @@ export class PublicClientApplication

return this.getTokenCache().getAllAccounts();
}

/**
* Attempts to retrieve the redirectUri from the loopback server. If the loopback server does not start listening for requests within the timeout this will throw.
* @param loopbackClient
* @returns
*/
private async waitForRedirectUri(
loopbackClient: ILoopbackClient
): Promise<string> {
return new Promise<string>((resolve, reject) => {
let ticks = 0;
const id = setInterval(() => {
if (
LOOPBACK_SERVER_CONSTANTS.TIMEOUT_MS /
LOOPBACK_SERVER_CONSTANTS.INTERVAL_MS <
ticks
) {
clearInterval(id);
reject(NodeAuthError.createLoopbackServerTimeoutError());
return;
}

try {
const r = loopbackClient.getRedirectUri();
clearInterval(id);
resolve(r);
return;
} catch (e) {
if (
e instanceof AuthError &&
e.errorCode ===
NodeAuthErrorMessage.noLoopbackServerExists.code
) {
// Loopback server is not listening yet
ticks++;
return;
}
clearInterval(id);
reject(e);
return;
}
}, LOOPBACK_SERVER_CONSTANTS.INTERVAL_MS);
});
}
}
51 changes: 18 additions & 33 deletions lib/msal-node/src/network/LoopbackClient.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,11 @@ import {
} from "@azure/msal-common";
import http from "http";
import { NodeAuthError } from "../error/NodeAuthError.js";
import { Constants, LOOPBACK_SERVER_CONSTANTS } from "../utils/Constants.js";
import { Constants } from "../utils/Constants.js";
import { ILoopbackClient } from "./ILoopbackClient.js";

export class LoopbackClient implements ILoopbackClient {
private server: http.Server;
private server: http.Server | undefined;

/**
* Spins up a loopback server which returns the server response when the localhost redirectUri is hit
Expand All @@ -27,17 +27,14 @@ export class LoopbackClient implements ILoopbackClient {
successTemplate?: string,
errorTemplate?: string
): Promise<ServerAuthorizationCodeResponse> {
if (!!this.server) {
if (this.server) {
throw NodeAuthError.createLoopbackServerAlreadyExistsError();
}

const authCodeListener = new Promise<ServerAuthorizationCodeResponse>(
return new Promise<ServerAuthorizationCodeResponse>(
(resolve, reject) => {
this.server = http.createServer(
async (
req: http.IncomingMessage,
res: http.ServerResponse
) => {
(req: http.IncomingMessage, res: http.ServerResponse) => {
const url = req.url;
if (!url) {
res.end(
Expand All @@ -59,7 +56,7 @@ export class LoopbackClient implements ILoopbackClient {
const authCodeResponse =
UrlString.getDeserializedQueryString(url);
if (authCodeResponse.code) {
const redirectUri = await this.getRedirectUri();
const redirectUri = this.getRedirectUri();
res.writeHead(HttpStatus.REDIRECT, {
location: redirectUri,
}); // Prevent auth code from being saved in the browser history
Expand All @@ -71,36 +68,14 @@ export class LoopbackClient implements ILoopbackClient {
this.server.listen(0); // Listen on any available port
}
);

// Wait for server to be listening
await new Promise<void>((resolve) => {
let ticks = 0;
const id = setInterval(() => {
if (
LOOPBACK_SERVER_CONSTANTS.TIMEOUT_MS /
LOOPBACK_SERVER_CONSTANTS.INTERVAL_MS <
ticks
) {
throw NodeAuthError.createLoopbackServerTimeoutError();
}

if (this.server.listening) {
clearInterval(id);
resolve();
}
ticks++;
}, LOOPBACK_SERVER_CONSTANTS.INTERVAL_MS);
});

return authCodeListener;
}

/**
* Get the port that the loopback server is running on
* @returns
*/
getRedirectUri(): string {
if (!this.server) {
if (!this.server || !this.server.listening) {
throw NodeAuthError.createNoLoopbackServerExistsError();
}

Expand All @@ -119,8 +94,18 @@ export class LoopbackClient implements ILoopbackClient {
* Close the loopback server
*/
closeServer(): void {
if (!!this.server) {
if (this.server) {
// Only stops accepting new connections, server will close once open/idle connections are closed.
this.server.close();

if (typeof this.server.closeAllConnections === "function") {
/*
* Close open/idle connections. This API is available in Node versions 18.2 and higher
*/
this.server.closeAllConnections();
}
this.server.unref();
this.server = undefined;
}
}
}
139 changes: 138 additions & 1 deletion lib/msal-node/test/client/PublicClientApplication.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -370,6 +370,77 @@ describe("PublicClientApplication", () => {
expect(response.account).toEqual(mockAuthenticationResult.account);
});

test("acquireTokenInteractive - getting redirectUri waits for server to start", async () => {
const authApp = new PublicClientApplication(appConfig);

let redirectUri: string;

// mock listener to wait 2 seconds before starting server
let originalListen = LoopbackClient.prototype.listenForAuthCode;
const listenerSpy = jest.spyOn(
LoopbackClient.prototype,
"listenForAuthCode"
);
listenerSpy.mockImplementation(() => {
return new Promise<void>((resolve) => {
setTimeout(() => {
resolve();
}, 2000);
}).then(
() => originalListen.call(listenerSpy.mock.instances[0]) // call original function and pass in the 'this' context
);
});

const openBrowser = (url: string) => {
expect(
url.startsWith("https://login.microsoftonline.com")
).toBe(true);
http.get(
`${redirectUri}?code=${TEST_CONSTANTS.AUTHORIZATION_CODE}`
);
return Promise.resolve();
};
const request: InteractiveRequest = {
scopes: TEST_CONSTANTS.DEFAULT_GRAPH_SCOPE,
openBrowser: openBrowser,
};

const MockAuthorizationCodeClient =
getMsalCommonAutoMock().AuthorizationCodeClient;
jest.spyOn(
msalCommon,
"AuthorizationCodeClient"
).mockImplementation(
(config) => new MockAuthorizationCodeClient(config)
);

jest.spyOn(
MockAuthorizationCodeClient.prototype,
"getAuthCodeUrl"
).mockImplementation((req) => {
redirectUri = req.redirectUri;
return Promise.resolve(TEST_CONSTANTS.AUTH_CODE_URL);
});

jest.spyOn(
MockAuthorizationCodeClient.prototype,
"acquireToken"
).mockImplementation((tokenRequest) => {
expect(tokenRequest.scopes).toEqual([
...TEST_CONSTANTS.DEFAULT_GRAPH_SCOPE,
...TEST_CONSTANTS.DEFAULT_OIDC_SCOPES,
]);
return Promise.resolve(mockAuthenticationResult);
});

const response = await authApp.acquireTokenInteractive(request);
expect(response.idToken).toEqual(mockAuthenticationResult.idToken);
expect(response.accessToken).toEqual(
mockAuthenticationResult.accessToken
);
expect(response.account).toEqual(mockAuthenticationResult.account);
});

test("acquireTokenInteractive - with custom loopback client succeeds", async () => {
const authApp = new PublicClientApplication(appConfig);

Expand Down Expand Up @@ -523,7 +594,7 @@ describe("PublicClientApplication", () => {
});
});

test("acquireTokenInteractive - loopback server is closed on error", async () => {
test("acquireTokenInteractive - loopback server is closed on error", (done) => {
const authApp = new PublicClientApplication(appConfig);

const openBrowser = (url: string) => {
Expand Down Expand Up @@ -583,6 +654,72 @@ describe("PublicClientApplication", () => {
authApp.acquireTokenInteractive(request).catch((e) => {
expect(e).toBe("Browser open error");
expect(mockCloseServer).toHaveBeenCalledTimes(1);
done();
});
});

test("acquireTokenInteractive - authCode listener rejections are handled", (done) => {
const authApp = new PublicClientApplication(appConfig);

const openBrowser = (url: string) => {
expect(
url.startsWith("https://login.microsoftonline.com")
).toBe(true);
return Promise.resolve();
};

// mock listener to wait 2 seconds then throw
let originalListen = LoopbackClient.prototype.listenForAuthCode;
const listenerSpy = jest.spyOn(
LoopbackClient.prototype,
"listenForAuthCode"
);
listenerSpy.mockImplementation(async () => {
return new Promise((resolve, reject) => {
setTimeout(() => {
reject("listener error");
}, 2000);
originalListen
.call(listenerSpy.mock.instances[0]) // call original function and pass in the 'this' context
.then((result) => resolve(result)); // This should never be called because the server will never be hit
});
});

jest.spyOn(
LoopbackClient.prototype,
"getRedirectUri"
).mockImplementation(() => TEST_CONSTANTS.REDIRECT_URI);
const mockCloseServer = jest.spyOn(
LoopbackClient.prototype,
"closeServer"
);

const request: InteractiveRequest = {
scopes: TEST_CONSTANTS.DEFAULT_GRAPH_SCOPE,
openBrowser: openBrowser,
};

const MockAuthorizationCodeClient =
getMsalCommonAutoMock().AuthorizationCodeClient;
jest.spyOn(
msalCommon,
"AuthorizationCodeClient"
).mockImplementation(
(config) => new MockAuthorizationCodeClient(config)
);

jest.spyOn(
MockAuthorizationCodeClient.prototype,
"getAuthCodeUrl"
).mockImplementation((req) => {
expect(req.redirectUri).toEqual(TEST_CONSTANTS.REDIRECT_URI);
return Promise.resolve(TEST_CONSTANTS.AUTH_CODE_URL);
});

authApp.acquireTokenInteractive(request).catch((e) => {
expect(e).toBe("listener error");
expect(mockCloseServer).toHaveBeenCalled();
done();
});
});
});
Expand Down

0 comments on commit b69ed83

Please sign in to comment.