Skip to content

Commit

Permalink
add: support for visual oddball analysis
Browse files Browse the repository at this point in the history
  • Loading branch information
oreHGA committed Sep 6, 2024
1 parent cb1cc65 commit 80b7305
Show file tree
Hide file tree
Showing 3 changed files with 153 additions and 9 deletions.
151 changes: 145 additions & 6 deletions analysis_api/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@
from flask_cors import CORS
import base64
import re
import functools
import numpy as np
from scipy import stats

import eeg
import cocoa_pad
Expand All @@ -22,6 +25,24 @@ def encode_image_to_base64(image_path):
encoded_string = base64.b64encode(image_file.read()).decode('utf-8')
return encoded_string

def compute_confidence_interval(data, confidence=0.95):
"""
Compute the confidence interval for the provided data.
Parameters:
- data (np.array): Array of data points.
- confidence (float): Confidence level (default is 0.95).
Returns:
- ci_lower (np.array): Lower bound of the confidence interval.
- ci_upper (np.array): Upper bound of the confidence interval.
"""
n = data.shape[0]
m = np.mean(data, axis=0)
se = stats.sem(data, axis=0)
h = se * stats.t.ppf((1 + confidence) / 2., n - 1)
return m - h, m + h

# TODO: handle multiple files
@app.route('/api/v1/process_eeg', methods=['POST'])
def process_eeg():
Expand Down Expand Up @@ -94,11 +115,131 @@ def process_eeg():
print("error", e)
return jsonify({'error': str(e)}), 500

# TODO: endpoint for ERP analysis
@app.route('/api/v1/process_eeg_erp', methods=['POST'])
def process_eeg_erp():
def validate_eeg_file(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
if 'eegFile' not in request.files:
return jsonify({'error': 'No EEG file submitted for processing'}), 400
return func(*args, **kwargs)
return wrapper

@app.route('/api/v1/process_visual_oddball', methods=['POST'])
@validate_eeg_file
def process_visual_oddball():
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import mne
import json
import io
try:
return jsonify({'response': "works perfect"}), 200
eegFile = request.files['eegFile']
samplingFrequency = int(request.form['samplingFrequency'])
stimulusFile = request.files['stimulusFile']

# assumptions, eeg file is csv.. situmuls file is .json
eeg_df = pd.read_csv(eegFile)
eeg_df.drop(columns=['index'], inplace=True)
sfreq = samplingFrequency

# Find periods of interest that overlap both eeg and stimulus
eeg_timestamps = eeg_df['unixTimestamp'].tolist()
eeg_timestamps_range = (min(eeg_timestamps), max(eeg_timestamps))
stimulus_json = json.loads(stimulusFile.read().decode('utf-8'))
filtered_json_events = [trial for trial in stimulus_json['trials'] if 'unixTimestamp' in trial and
eeg_timestamps_range[0] <= trial['unixTimestamp'] <= eeg_timestamps_range[1]]
if len(filtered_json_events) == 0:
raise ValueError("No valid events found after filtering with the CSV timestamps range.")

# Create MNE events array from filtered JSON events
event_id = {'standard': 1, 'oddball': 2}
events = []
start_time = eeg_timestamps[0] / 1e3 # Convert to seconds #TODO: check if it's in milliseconds first
for trial in filtered_json_events:
if 'value' in trial:
if 'oddball' in trial['value']:
event_type = event_id['oddball']
elif 'standard' in trial['value']:
event_type = event_id['standard']
else:
continue
event_time = trial['unixTimestamp'] / 1e3 # Convert to seconds
event_sample = int((event_time - start_time) * sfreq)
events.append([event_sample, 0, event_type])

events = np.array(events)
if len(events) == 0:
raise ValueError("No valid events found for creating epochs.")

# Create MNE Raw object
info = mne.create_info(ch_names=list(eeg_df.columns[1:]), sfreq=sfreq, ch_types='eeg')
eeg_df = eeg_df.values[:, 1:].T
eeg_df *= 1e-6 # convert from uV to V
raw = mne.io.RawArray(eeg_df, info)
raw.set_montage('standard_1020')

lfreq = 1
ufreq = 40

# Filter the data
raw.filter(lfreq, ufreq, fir_design='firwin')

# Create epochs
epochs = mne.Epochs(raw, events, event_id, tmin=-0.2, tmax=0.8, baseline=(None, 0), preload=True)

# Compute the average ERP for each condition
evoked_standard = epochs['standard'].average()
evoked_oddball = epochs['oddball'].average()

print("ERP data computed.")

# Plot the ERP for the standard and oddball stimulus
print("Plotting ERP for standard and oddball stimulus...")
fig, axes = plt.subplots(2, 2, figsize=(15, 10))
fig.suptitle('ERP for Standard and Oddball Stimulus')

for idx, ch_name in enumerate(evoked_standard.ch_names):
ax = axes[idx // 2, idx % 2]

# Plot standard ERP
ax.plot(evoked_standard.times, evoked_standard.data[idx],
label='Standard', color='blue')

# Plot oddball ERP
ax.plot(evoked_oddball.times, evoked_oddball.data[idx],
label='Oddball', color='red')

# Calculate and plot the difference waveform (Oddball - Standard)
# difference_wave = evoked_oddball.data[idx] - evoked_standard.data[idx]
# ax.plot(evoked_standard.times, difference_wave,
# label='Difference (Oddball - Standard)', color='green')

ax.axvline(0, color='k', linestyle='--', label='Stimulus Onset')
ax.set_title(f'Channel: {ch_name}')
ax.set_xlabel('Time (s)')
ax.set_ylabel('Amplitude (µV)')
ax.legend()
ax.grid(True)

plt.tight_layout()

# Save the plot to a BytesIO object
img_buffer = io.BytesIO()
plt.savefig(img_buffer, format='png')
img_buffer.seek(0)

# Encode the image to base64
img_str = base64.b64encode(img_buffer.getvalue()).decode()

# Create a dictionary with the image data
erp_plot = {
"key": "ERP Standard vs Oddball (All Channels)",
"value": f"data:image/png;base64,{img_str}",
"summary": "ERP plot for standard and oddball stimulus across all channels.\n Filtered between 1 and 40 Hz. Using MNE firwin"
}
plt.close()

return jsonify({'images': [erp_plot], 'summary': "ERP Standard vs Oddball (All Channels)"}), 200
except Exception as e:
return jsonify({'error': 'error processing', 'message': e}), 500

Expand All @@ -124,8 +265,6 @@ def process_eeg_fooof():
eegFile = request.files['eegFile']
samplingFrequency = int(request.form['samplingFrequency'])

print("eegFile", eegFile)

# Check if the file has a filename
if eegFile.filename == '':
return jsonify({'error': 'No selected EEG file'}), 400
Expand Down
7 changes: 5 additions & 2 deletions frontend/src/pages/analysis.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,8 @@ const AnalysisPage: NextPage = () => {

try {
setLoading(true);
const response = await fetch(`${process.env["NEXT_PUBLIC_ANALYSIS_SERVER_URL"]}/api/v1/process_eeg_fooof`, {
const urlEndpoint = stimulusFile ? "process_visual_oddball" : "process_eeg_fooof";
const response = await fetch(`${process.env["NEXT_PUBLIC_ANALYSIS_SERVER_URL"]}/api/v1/${urlEndpoint}`, {
method: "POST",
body: formData,
});
Expand Down Expand Up @@ -233,9 +234,11 @@ export const getServerSideProps: GetServerSideProps = async ({ req, res }) => {
const session = await getServerSession(req, res, authOptions);

if (!session) {
// login the user
const currentUrl = `${req.url}`;
return {
redirect: {
destination: "/auth/login",
destination: `/auth/login?callbackUrl=${encodeURIComponent(currentUrl)}`,
permanent: false,
},
};
Expand Down
4 changes: 3 additions & 1 deletion frontend/src/pages/recordings.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -74,9 +74,11 @@ export const getServerSideProps: GetServerSideProps = async ({ req, res }) => {
const session = await getServerSession(req, res, authOptions);

if (!session) {
// login the user
const currentUrl = `${req.url}`;
return {
redirect: {
destination: "/auth/login",
destination: `/auth/login?callbackUrl=${encodeURIComponent(currentUrl)}`,
permanent: false,
},
};
Expand Down

0 comments on commit 80b7305

Please sign in to comment.