From 80b7305fb58052c65f57292b1d2a14312d643532 Mon Sep 17 00:00:00 2001 From: Ore Ogundipe Date: Fri, 6 Sep 2024 08:49:05 -0400 Subject: [PATCH] add: support for visual oddball analysis --- analysis_api/app.py | 151 ++++++++++++++++++++++++++++-- frontend/src/pages/analysis.tsx | 7 +- frontend/src/pages/recordings.tsx | 4 +- 3 files changed, 153 insertions(+), 9 deletions(-) diff --git a/analysis_api/app.py b/analysis_api/app.py index a704c7d8..4546c4f3 100644 --- a/analysis_api/app.py +++ b/analysis_api/app.py @@ -9,6 +9,9 @@ from flask_cors import CORS import base64 import re +import functools +import numpy as np +from scipy import stats import eeg import cocoa_pad @@ -22,6 +25,24 @@ def encode_image_to_base64(image_path): encoded_string = base64.b64encode(image_file.read()).decode('utf-8') return encoded_string +def compute_confidence_interval(data, confidence=0.95): + """ + Compute the confidence interval for the provided data. + + Parameters: + - data (np.array): Array of data points. + - confidence (float): Confidence level (default is 0.95). + + Returns: + - ci_lower (np.array): Lower bound of the confidence interval. + - ci_upper (np.array): Upper bound of the confidence interval. + """ + n = data.shape[0] + m = np.mean(data, axis=0) + se = stats.sem(data, axis=0) + h = se * stats.t.ppf((1 + confidence) / 2., n - 1) + return m - h, m + h + # TODO: handle multiple files @app.route('/api/v1/process_eeg', methods=['POST']) def process_eeg(): @@ -94,11 +115,131 @@ def process_eeg(): print("error", e) return jsonify({'error': str(e)}), 500 -# TODO: endpoint for ERP analysis -@app.route('/api/v1/process_eeg_erp', methods=['POST']) -def process_eeg_erp(): +def validate_eeg_file(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + if 'eegFile' not in request.files: + return jsonify({'error': 'No EEG file submitted for processing'}), 400 + return func(*args, **kwargs) + return wrapper + +@app.route('/api/v1/process_visual_oddball', methods=['POST']) +@validate_eeg_file +def process_visual_oddball(): + import matplotlib + matplotlib.use('Agg') + import matplotlib.pyplot as plt + import mne + import json + import io try: - return jsonify({'response': "works perfect"}), 200 + eegFile = request.files['eegFile'] + samplingFrequency = int(request.form['samplingFrequency']) + stimulusFile = request.files['stimulusFile'] + + # assumptions, eeg file is csv.. situmuls file is .json + eeg_df = pd.read_csv(eegFile) + eeg_df.drop(columns=['index'], inplace=True) + sfreq = samplingFrequency + + # Find periods of interest that overlap both eeg and stimulus + eeg_timestamps = eeg_df['unixTimestamp'].tolist() + eeg_timestamps_range = (min(eeg_timestamps), max(eeg_timestamps)) + stimulus_json = json.loads(stimulusFile.read().decode('utf-8')) + filtered_json_events = [trial for trial in stimulus_json['trials'] if 'unixTimestamp' in trial and + eeg_timestamps_range[0] <= trial['unixTimestamp'] <= eeg_timestamps_range[1]] + if len(filtered_json_events) == 0: + raise ValueError("No valid events found after filtering with the CSV timestamps range.") + + # Create MNE events array from filtered JSON events + event_id = {'standard': 1, 'oddball': 2} + events = [] + start_time = eeg_timestamps[0] / 1e3 # Convert to seconds #TODO: check if it's in milliseconds first + for trial in filtered_json_events: + if 'value' in trial: + if 'oddball' in trial['value']: + event_type = event_id['oddball'] + elif 'standard' in trial['value']: + event_type = event_id['standard'] + else: + continue + event_time = trial['unixTimestamp'] / 1e3 # Convert to seconds + event_sample = int((event_time - start_time) * sfreq) + events.append([event_sample, 0, event_type]) + + events = np.array(events) + if len(events) == 0: + raise ValueError("No valid events found for creating epochs.") + + # Create MNE Raw object + info = mne.create_info(ch_names=list(eeg_df.columns[1:]), sfreq=sfreq, ch_types='eeg') + eeg_df = eeg_df.values[:, 1:].T + eeg_df *= 1e-6 # convert from uV to V + raw = mne.io.RawArray(eeg_df, info) + raw.set_montage('standard_1020') + + lfreq = 1 + ufreq = 40 + + # Filter the data + raw.filter(lfreq, ufreq, fir_design='firwin') + + # Create epochs + epochs = mne.Epochs(raw, events, event_id, tmin=-0.2, tmax=0.8, baseline=(None, 0), preload=True) + + # Compute the average ERP for each condition + evoked_standard = epochs['standard'].average() + evoked_oddball = epochs['oddball'].average() + + print("ERP data computed.") + + # Plot the ERP for the standard and oddball stimulus + print("Plotting ERP for standard and oddball stimulus...") + fig, axes = plt.subplots(2, 2, figsize=(15, 10)) + fig.suptitle('ERP for Standard and Oddball Stimulus') + + for idx, ch_name in enumerate(evoked_standard.ch_names): + ax = axes[idx // 2, idx % 2] + + # Plot standard ERP + ax.plot(evoked_standard.times, evoked_standard.data[idx], + label='Standard', color='blue') + + # Plot oddball ERP + ax.plot(evoked_oddball.times, evoked_oddball.data[idx], + label='Oddball', color='red') + + # Calculate and plot the difference waveform (Oddball - Standard) + # difference_wave = evoked_oddball.data[idx] - evoked_standard.data[idx] + # ax.plot(evoked_standard.times, difference_wave, + # label='Difference (Oddball - Standard)', color='green') + + ax.axvline(0, color='k', linestyle='--', label='Stimulus Onset') + ax.set_title(f'Channel: {ch_name}') + ax.set_xlabel('Time (s)') + ax.set_ylabel('Amplitude (µV)') + ax.legend() + ax.grid(True) + + plt.tight_layout() + + # Save the plot to a BytesIO object + img_buffer = io.BytesIO() + plt.savefig(img_buffer, format='png') + img_buffer.seek(0) + + # Encode the image to base64 + img_str = base64.b64encode(img_buffer.getvalue()).decode() + + # Create a dictionary with the image data + erp_plot = { + "key": "ERP Standard vs Oddball (All Channels)", + "value": f"data:image/png;base64,{img_str}", + "summary": "ERP plot for standard and oddball stimulus across all channels.\n Filtered between 1 and 40 Hz. Using MNE firwin" + } + plt.close() + + return jsonify({'images': [erp_plot], 'summary': "ERP Standard vs Oddball (All Channels)"}), 200 except Exception as e: return jsonify({'error': 'error processing', 'message': e}), 500 @@ -124,8 +265,6 @@ def process_eeg_fooof(): eegFile = request.files['eegFile'] samplingFrequency = int(request.form['samplingFrequency']) - print("eegFile", eegFile) - # Check if the file has a filename if eegFile.filename == '': return jsonify({'error': 'No selected EEG file'}), 400 diff --git a/frontend/src/pages/analysis.tsx b/frontend/src/pages/analysis.tsx index 7d007702..8018541f 100644 --- a/frontend/src/pages/analysis.tsx +++ b/frontend/src/pages/analysis.tsx @@ -60,7 +60,8 @@ const AnalysisPage: NextPage = () => { try { setLoading(true); - const response = await fetch(`${process.env["NEXT_PUBLIC_ANALYSIS_SERVER_URL"]}/api/v1/process_eeg_fooof`, { + const urlEndpoint = stimulusFile ? "process_visual_oddball" : "process_eeg_fooof"; + const response = await fetch(`${process.env["NEXT_PUBLIC_ANALYSIS_SERVER_URL"]}/api/v1/${urlEndpoint}`, { method: "POST", body: formData, }); @@ -233,9 +234,11 @@ export const getServerSideProps: GetServerSideProps = async ({ req, res }) => { const session = await getServerSession(req, res, authOptions); if (!session) { + // login the user + const currentUrl = `${req.url}`; return { redirect: { - destination: "/auth/login", + destination: `/auth/login?callbackUrl=${encodeURIComponent(currentUrl)}`, permanent: false, }, }; diff --git a/frontend/src/pages/recordings.tsx b/frontend/src/pages/recordings.tsx index 3d491007..98fca320 100644 --- a/frontend/src/pages/recordings.tsx +++ b/frontend/src/pages/recordings.tsx @@ -74,9 +74,11 @@ export const getServerSideProps: GetServerSideProps = async ({ req, res }) => { const session = await getServerSession(req, res, authOptions); if (!session) { + // login the user + const currentUrl = `${req.url}`; return { redirect: { - destination: "/auth/login", + destination: `/auth/login?callbackUrl=${encodeURIComponent(currentUrl)}`, permanent: false, }, };