diff --git a/src/components/EnterpriseApp/EnterpriseAppContextProvider.jsx b/src/components/EnterpriseApp/EnterpriseAppContextProvider.jsx
index 9d7960ae21..39cbb05f05 100644
--- a/src/components/EnterpriseApp/EnterpriseAppContextProvider.jsx
+++ b/src/components/EnterpriseApp/EnterpriseAppContextProvider.jsx
@@ -1,10 +1,12 @@
-import React, { createContext, useMemo } from 'react';
+import React, { createContext, useContext, useMemo } from 'react';
import PropTypes from 'prop-types';
+import { AppContext } from '@edx/frontend-platform/react';
import { EnterpriseSubsidiesContext, useEnterpriseSubsidiesContext } from '../EnterpriseSubsidiesContext';
import { SubsidyRequestsContext, useSubsidyRequestsContext } from '../subsidy-requests/SubsidyRequestsContext';
import {
useEnterpriseCurationContext,
+ useUpdateActiveEnterpriseForUser,
} from './data/hooks';
import EnterpriseAppSkeleton from './EnterpriseAppSkeleton';
@@ -49,6 +51,7 @@ const EnterpriseAppContextProvider = ({
enablePortalLearnerCreditManagementScreen,
children,
}) => {
+ const { authenticatedUser } = useContext(AppContext);
// subsidies for the enterprise customer
const enterpriseSubsidiesContext = useEnterpriseSubsidiesContext({
enterpriseId,
@@ -68,10 +71,16 @@ const EnterpriseAppContextProvider = ({
curationTitleForCreation: enterpriseName,
});
+ const { isLoading: isUpdatingActiveEnterprise } = useUpdateActiveEnterpriseForUser({
+ enterpriseId,
+ user: authenticatedUser,
+ });
+
const isLoading = (
subsidyRequestsContext.isLoading
|| enterpriseSubsidiesContext.isLoading
|| enterpriseCurationContext.isLoading
+ || isUpdatingActiveEnterprise
);
// [tech debt] consolidate the other context values (e.g., useSubsidyRequestsContext)
diff --git a/src/components/EnterpriseApp/EnterpriseAppContextProvider.test.jsx b/src/components/EnterpriseApp/EnterpriseAppContextProvider.test.jsx
index 892200045b..d868fc11df 100644
--- a/src/components/EnterpriseApp/EnterpriseAppContextProvider.test.jsx
+++ b/src/components/EnterpriseApp/EnterpriseAppContextProvider.test.jsx
@@ -23,31 +23,43 @@ describe('', () => {
isLoadingEnterpriseSubsidies: true,
isLoadingSubsidyRequests: false,
isLoadingEnterpriseCuration: false,
+ isLoadingUpdateActiveEnterpriseForUser: false,
},
{
isLoadingEnterpriseSubsidies: false,
isLoadingSubsidyRequests: true,
isLoadingEnterpriseCuration: false,
+ isLoadingUpdateActiveEnterpriseForUser: false,
},
{
isLoadingEnterpriseSubsidies: false,
isLoadingSubsidyRequests: false,
isLoadingEnterpriseCuration: true,
+ isLoadingUpdateActiveEnterpriseForUser: false,
},
{
isLoadingEnterpriseSubsidies: true,
isLoadingSubsidyRequests: true,
isLoadingEnterpriseCuration: false,
+ isLoadingUpdateActiveEnterpriseForUser: false,
+ },
+ {
+ isLoadingEnterpriseSubsidies: false,
+ isLoadingSubsidyRequests: false,
+ isLoadingEnterpriseCuration: false,
+ isLoadingUpdateActiveEnterpriseForUser: true,
},
{
isLoadingEnterpriseSubsidies: true,
isLoadingSubsidyRequests: true,
isLoadingEnterpriseCuration: true,
+ isLoadingUpdateActiveEnterpriseForUser: true,
},
])('renders when: %s', async ({
isLoadingEnterpriseSubsidies,
isLoadingSubsidyRequests,
isLoadingEnterpriseCuration,
+ isLoadingUpdateActiveEnterpriseForUser,
}) => {
const mockUseEnterpriseSubsidiesContext = jest.spyOn(enterpriseSubsidiesContext, 'useEnterpriseSubsidiesContext').mockReturnValue({
isLoading: isLoadingEnterpriseSubsidies,
@@ -62,6 +74,11 @@ describe('', () => {
isLoading: isLoadingEnterpriseCuration,
},
);
+ const mockUseUpdateActiveEnterpriseForUser = jest.spyOn(hooks, 'useUpdateActiveEnterpriseForUser').mockReturnValue(
+ {
+ isLoading: isLoadingUpdateActiveEnterpriseForUser,
+ },
+ );
render(
', () => {
);
await waitFor(() => {
+ expect(mockUseUpdateActiveEnterpriseForUser).toHaveBeenCalled();
expect(mockUseSubsidyRequestsContext).toHaveBeenCalled();
expect(mockUseEnterpriseSubsidiesContext).toHaveBeenCalled();
expect(mockUseEnterpriseCurationContext).toHaveBeenCalled();
- if (isLoadingEnterpriseSubsidies || isLoadingSubsidyRequests || isLoadingEnterpriseCuration) {
+ if (
+ isLoadingEnterpriseSubsidies
+ || isLoadingSubsidyRequests
+ || isLoadingEnterpriseCuration
+ || isLoadingUpdateActiveEnterpriseForUser
+ ) {
expect(screen.getByText('Loading...'));
} else {
expect(screen.getByText('children'));
diff --git a/src/components/EnterpriseApp/data/hooks/index.js b/src/components/EnterpriseApp/data/hooks/index.js
index 6285708f69..a89ad36440 100644
--- a/src/components/EnterpriseApp/data/hooks/index.js
+++ b/src/components/EnterpriseApp/data/hooks/index.js
@@ -1,2 +1,3 @@
export { default as useEnterpriseCuration } from './useEnterpriseCuration';
export { default as useEnterpriseCurationContext } from './useEnterpriseCurationContext';
+export { default as useUpdateActiveEnterpriseForUser } from './useUpdateActiveEnterpriseForUser';
diff --git a/src/components/EnterpriseApp/data/hooks/useUpdateActiveEnterpriseForUser.js b/src/components/EnterpriseApp/data/hooks/useUpdateActiveEnterpriseForUser.js
new file mode 100644
index 0000000000..d4121e0ee4
--- /dev/null
+++ b/src/components/EnterpriseApp/data/hooks/useUpdateActiveEnterpriseForUser.js
@@ -0,0 +1,29 @@
+import { logError } from '@edx/frontend-platform/logging';
+import {
+ useQuery,
+} from '@tanstack/react-query';
+
+import LmsApiService from '../../../../data/services/LmsApiService';
+
+const useUpdateActiveEnterpriseForUser = ({ enterpriseId, user }) => {
+ const { username } = user;
+ const { isLoading, error } = useQuery({
+ queryKey: ['updateUsersActiveEnterprise'],
+ queryFn: async () => {
+ await LmsApiService.getActiveLinkedEnterprise(username).then(async (linkedEnterprise) => {
+ if (linkedEnterprise.uuid !== enterpriseId) {
+ await LmsApiService.updateUserActiveEnterprise(enterpriseId);
+ }
+ });
+ return true;
+ },
+ });
+
+ if (error) { logError(`Could not set active enterprise for learner, failed with error: ${logError}`); }
+
+ return {
+ isLoading,
+ };
+};
+
+export default useUpdateActiveEnterpriseForUser;
diff --git a/src/components/EnterpriseApp/data/hooks/useUpdateActiveEnterpriseForUser.test.jsx b/src/components/EnterpriseApp/data/hooks/useUpdateActiveEnterpriseForUser.test.jsx
new file mode 100644
index 0000000000..a522c1d3f0
--- /dev/null
+++ b/src/components/EnterpriseApp/data/hooks/useUpdateActiveEnterpriseForUser.test.jsx
@@ -0,0 +1,75 @@
+import { renderHook } from '@testing-library/react-hooks';
+import { QueryClientProvider } from '@tanstack/react-query';
+import { logError } from '@edx/frontend-platform/logging';
+import { useUpdateActiveEnterpriseForUser } from './index';
+import LmsApiService from '../../../../data/services/LmsApiService';
+import { queryClient } from '../../../test/testUtils';
+
+jest.mock('../../../../data/services/LmsApiService');
+jest.mock('@edx/frontend-platform/logging', () => ({
+ ...jest.requireActual('@edx/frontend-platform/logging'),
+ logError: jest.fn(),
+}));
+
+describe('useUpdateActiveEnterpriseForUser', () => {
+ const wrapper = ({ children }) => (
+
+ {children}
+
+ );
+ const mockEnterpriseId = 'enterprise-uuid';
+ const mockUser = { username: 'joe_shmoe' };
+ const connectedEnterprise = 'someID';
+ beforeEach(() => {
+ LmsApiService.getActiveLinkedEnterprise.mockResolvedValue({ uuid: connectedEnterprise });
+ });
+
+ afterEach(() => jest.clearAllMocks());
+
+ it("should update user's active enterprise if it differs from the current enterprise", async () => {
+ const { result, waitForNextUpdate } = renderHook(
+ () => useUpdateActiveEnterpriseForUser({
+ enterpriseId: mockEnterpriseId,
+ user: mockUser,
+ }),
+ { wrapper },
+ );
+ expect(result.current.isLoading).toBe(true);
+
+ await waitForNextUpdate();
+
+ expect(LmsApiService.updateUserActiveEnterprise).toHaveBeenCalledTimes(1);
+ expect(result.current.isLoading).toBe(false);
+ });
+
+ it('should do nothing if active enterprise is the same as current enterprise', async () => {
+ // Pass the value of the enterprise ID returned by ``getActiveLinkedEnterprise`` to the hook
+ const { waitForNextUpdate } = renderHook(
+ () => useUpdateActiveEnterpriseForUser({
+ enterpriseId: connectedEnterprise,
+ user: mockUser,
+ }),
+ { wrapper },
+ );
+ await waitForNextUpdate();
+ expect(LmsApiService.updateUserActiveEnterprise).toHaveBeenCalledTimes(0);
+ });
+
+ it('should handle errors', async () => {
+ LmsApiService.updateUserActiveEnterprise.mockRejectedValueOnce(Error('uh oh'));
+ const { result, waitForNextUpdate } = renderHook(
+ () => useUpdateActiveEnterpriseForUser({
+ enterpriseId: mockEnterpriseId,
+ user: mockUser,
+ }),
+ { wrapper },
+ );
+ expect(result.current.isLoading).toBe(true);
+
+ await waitForNextUpdate();
+
+ expect(LmsApiService.updateUserActiveEnterprise).toHaveBeenCalledTimes(1);
+ expect(result.current.isLoading).toBe(false);
+ expect(logError).toHaveBeenCalledTimes(1);
+ });
+});
diff --git a/src/data/services/LmsApiService.js b/src/data/services/LmsApiService.js
index b3c32f662d..fc961ecc08 100644
--- a/src/data/services/LmsApiService.js
+++ b/src/data/services/LmsApiService.js
@@ -1,5 +1,6 @@
import { getAuthenticatedHttpClient } from '@edx/frontend-platform/auth';
import { camelCaseObject } from '@edx/frontend-platform/utils';
+import { logError } from '@edx/frontend-platform/logging';
import { configuration } from '../../config';
import generateFormattedStatusUrl from './apiServiceUtils';
@@ -384,6 +385,38 @@ class LmsApiService {
const url = `${LmsApiService.baseUrl}/enterprise/api/v1/analytics-summary/${enterpriseUUID}`;
return LmsApiService.apiClient().post(url, formData);
}
+
+ static updateUserActiveEnterprise = (enterpriseId) => {
+ const url = `${configuration.LMS_BASE_URL}/enterprise/select/active/`;
+ const formData = new FormData();
+ formData.append('enterprise', enterpriseId);
+
+ return LmsApiService.apiClient().post(
+ url,
+ formData,
+ );
+ };
+
+ static fetchEnterpriseLearnerData(options) {
+ const enterpriseLearnerUrl = `${configuration.LMS_BASE_URL}/enterprise/api/v1/enterprise-learner/`;
+ const queryParams = new URLSearchParams({
+ ...options,
+ page: 1,
+ });
+ const url = `${enterpriseLearnerUrl}?${queryParams.toString()}`;
+ return LmsApiService.apiClient().get(url);
+ }
+
+ static async getActiveLinkedEnterprise(username) {
+ const response = await this.fetchEnterpriseLearnerData({ username });
+ const transformedResponse = camelCaseObject(response.data);
+ const enterprisesForLearner = transformedResponse.results;
+ const activeLinkedEnterprise = enterprisesForLearner.find(enterprise => enterprise.active);
+ if (!activeLinkedEnterprise) {
+ logError(`${username} does not have any active linked enterprise customers`);
+ }
+ return activeLinkedEnterprise.enterpriseCustomer;
+ }
}
export default LmsApiService;
diff --git a/src/data/services/tests/LmsApiService.test.js b/src/data/services/tests/LmsApiService.test.js
index 2ec7007a54..396e903eac 100644
--- a/src/data/services/tests/LmsApiService.test.js
+++ b/src/data/services/tests/LmsApiService.test.js
@@ -7,12 +7,15 @@ import { configuration } from '../../../config';
const lmsBaseUrl = `${configuration.LMS_BASE_URL}`;
const mockEnterpriseUUID = 'test-enterprise-id';
+const mockUsername = 'test_username';
const axiosMock = new MockAdapter(axios);
getAuthenticatedHttpClient.mockReturnValue(axios);
axiosMock.onAny().reply(200);
axios.patch = jest.fn();
+axios.post = jest.fn();
+axios.get = jest.fn();
describe('LmsApiService', () => {
test('updateEnterpriseCustomer calls the LMS to update the enterprise customer', () => {
@@ -41,4 +44,33 @@ describe('LmsApiService', () => {
{ primary_color: '#A8DABC' },
);
});
+ test('updateUserActiveEnterprise calls the LMS to update the active linked enterprise org', () => {
+ LmsApiService.updateUserActiveEnterprise(
+ mockEnterpriseUUID,
+ );
+ const expectedFormData = new FormData();
+ expectedFormData.append('enterprise', mockEnterpriseUUID);
+ expect(axios.post).toBeCalledWith(
+ `${lmsBaseUrl}/enterprise/select/active/`,
+ expectedFormData,
+ );
+ });
+ test('fetchEnterpriseLearnerData calls the LMS to fetch learner data', () => {
+ LmsApiService.fetchEnterpriseLearnerData({ username: mockUsername });
+ expect(axios.get).toBeCalledWith(
+ `${lmsBaseUrl}/enterprise/api/v1/enterprise-learner/?username=${mockUsername}&page=1`,
+ );
+ });
+ test('getActiveLinkedEnterprise returns the actively linked enterprise', async () => {
+ axios.get.mockReturnValue({
+ data: {
+ results: [{
+ active: true,
+ enterpriseCustomer: { uuid: 'test-uuid' },
+ }],
+ },
+ });
+ const activeCustomer = await LmsApiService.getActiveLinkedEnterprise(mockUsername);
+ expect(activeCustomer).toEqual({ uuid: 'test-uuid' });
+ });
});