Skip to content

Commit

Permalink
V0 migration model. Renaming for dropped, no surgery
Browse files Browse the repository at this point in the history
  • Loading branch information
CBroz1 committed Nov 29, 2023
1 parent 04ee070 commit 84da68d
Show file tree
Hide file tree
Showing 12 changed files with 856 additions and 250 deletions.
80 changes: 76 additions & 4 deletions src/spyglass/common/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,14 +56,10 @@
NwbfileKachery,
)
from .common_position import (
IntervalLinearizationSelection,
IntervalLinearizedPosition,
IntervalPositionInfo,
IntervalPositionInfoSelection,
LinearizationParameters,
PositionInfoParameters,
PositionVideo,
TrackGraph,
)
from .common_region import BrainRegion
from .common_sensors import SensorData
Expand All @@ -73,5 +69,81 @@
from .populate_all_common import populate_all_common
from .prepopulate import populate_from_yaml, prepopulate_default

from spyglass.linearization.v0 import ( # isort:skip
IntervalLinearizationSelection,
IntervalLinearizedPosition,
LinearizationParameters,
TrackGraph,
)

__all__ = [
"AnalysisNwbfile",
"AnalysisNwbfileKachery",
"BrainRegion",
"CameraDevice",
"DIOEvents",
"DataAcquisitionDevice",
"DataAcquisitionDeviceAmplifier",
"DataAcquisitionDeviceSystem",
"Electrode",
"ElectrodeGroup",
"FirFilterParameters",
"Institution",
"IntervalLinearizationSelection",
"IntervalLinearizedPosition",
"IntervalList",
"IntervalPositionInfo",
"IntervalPositionInfoSelection",
"LFP",
"LFPBand",
"LFPBandSelection",
"LFPSelection",
"Lab",
"LabMember",
"LabTeam",
"LinearizationParameters",
"Nwbfile",
"NwbfileKachery",
"PositionInfoParameters",
"PositionIntervalMap",
"PositionSource",
"PositionVideo",
"Probe",
"ProbeType",
"Raw",
"RawPosition",
"SampleCount",
"SensorData",
"Session",
"SessionGroup",
"StateScriptFile",
"Subject",
"Task",
"TaskEpoch",
"TrackGraph",
"VideoFile",
"close_nwb_files",
"convert_epoch_interval_name_to_position_interval_name",
"estimate_sampling_rate",
"get_data_interface",
"get_electrode_indices",
"get_nwb_file",
"get_raw_eseries",
"get_valid_intervals",
"interval_list_censor",
"interval_list_contains",
"interval_list_contains_ind",
"interval_list_excludes",
"interval_list_excludes_ind",
"interval_list_intersect",
"interval_list_union",
"intervals_by_length",
"os",
"populate_all_common",
"populate_from_yaml",
"prepopulate_default",
"sg",
]

if sg.config["prepopulate"]:
prepopulate_default()
213 changes: 26 additions & 187 deletions src/spyglass/common/common_position.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,18 +16,11 @@
)
from position_tools.core import gaussian_smooth
from tqdm import tqdm_notebook as tqdm
from track_linearization import (
get_linearized_position,
make_track_graph,
plot_graph_as_1D,
plot_track_graph,
)

from ..settings import raw_dir, video_dir
from ..utils.dj_helper_fn import fetch_nwb
from .common_behav import RawPosition, VideoFile
from .common_interval import IntervalList # noqa F401
from .common_nwbfile import AnalysisNwbfile
from spyglass.common.common_behav import RawPosition, VideoFile
from spyglass.common.common_nwbfile import AnalysisNwbfile
from spyglass.settings import raw_dir, video_dir
from spyglass.utils.dj_helper_fn import deprecated_factory, fetch_nwb

schema = dj.schema("common_position")

Expand Down Expand Up @@ -503,184 +496,30 @@ def _data_to_df(data, prefix="head_", add_frame_ind=False):
return df


@schema
class LinearizationParameters(dj.Lookup):
"""Choose whether to use an HMM to linearize position.
This can help when the euclidean distances between separate arms are too
close and the previous position has some information about which arm the
animal is on.
route_euclidean_distance_scaling: How much to prefer route distances between
successive time points that are closer to the euclidean distance. Smaller
numbers mean the route distance is more likely to be close to the euclidean
distance.
"""

definition = """
linearization_param_name : varchar(80) # name for this set of parameters
---
use_hmm = 0 : int # use HMM to determine linearization
route_euclidean_distance_scaling = 1.0 : float # Preference for euclidean.
sensor_std_dev = 5.0 : float # Uncertainty of position sensor (in cm).
# Biases the transition matrix to prefer the current track segment.
diagonal_bias = 0.5 : float
"""


@schema
class TrackGraph(dj.Manual):
"""Graph representation of track representing the spatial environment.
Used for linearizing position.
"""

definition = """
track_graph_name : varchar(80)
----
environment : varchar(80) # Type of Environment
node_positions : blob # 2D position of nodes, (n_nodes, 2)
edges: blob # shape (n_edges, 2)
linear_edge_order : blob # order of edges in linear space, (n_edges, 2)
linear_edge_spacing : blob # space btwn edges in linear space, (n_edges,)
"""

def get_networkx_track_graph(self, track_graph_parameters=None):
if track_graph_parameters is None:
track_graph_parameters = self.fetch1()
return make_track_graph(
node_positions=track_graph_parameters["node_positions"],
edges=track_graph_parameters["edges"],
)

def plot_track_graph(self, ax=None, draw_edge_labels=False, **kwds):
"""Plot the track graph in 2D position space."""
track_graph = self.get_networkx_track_graph()
plot_track_graph(
track_graph, ax=ax, draw_edge_labels=draw_edge_labels, **kwds
)

def plot_track_graph_as_1D(
self,
ax=None,
axis="x",
other_axis_start=0.0,
draw_edge_labels=False,
node_size=300,
node_color="#1f77b4",
):
"""Plot the track graph in 1D to see how the linearization is set up."""
track_graph_parameters = self.fetch1()
track_graph = self.get_networkx_track_graph(
track_graph_parameters=track_graph_parameters
)
plot_graph_as_1D(
track_graph,
edge_order=track_graph_parameters["linear_edge_order"],
edge_spacing=track_graph_parameters["linear_edge_spacing"],
ax=ax,
axis=axis,
other_axis_start=other_axis_start,
draw_edge_labels=draw_edge_labels,
node_size=node_size,
node_color=node_color,
)


@schema
class IntervalLinearizationSelection(dj.Lookup):
definition = """
-> IntervalPositionInfo
-> TrackGraph
-> LinearizationParameters
---
"""


@schema
class IntervalLinearizedPosition(dj.Computed):
"""Linearized position for a given interval"""

definition = """
-> IntervalLinearizationSelection
---
-> AnalysisNwbfile
linearized_position_object_id : varchar(40)
"""

def make(self, key):
print(f"Computing linear position for: {key}")

key["analysis_file_name"] = AnalysisNwbfile().create(
key["nwb_file_name"]
)

position_nwb = (
IntervalPositionInfo
& {
"nwb_file_name": key["nwb_file_name"],
"interval_list_name": key["interval_list_name"],
"position_info_param_name": key["position_info_param_name"],
}
).fetch_nwb()[0]

position = np.asarray(
position_nwb["head_position"].get_spatial_series().data
)
time = np.asarray(
position_nwb["head_position"].get_spatial_series().timestamps
)

linearization_parameters = (
LinearizationParameters()
& {"linearization_param_name": key["linearization_param_name"]}
).fetch1()
track_graph_info = (
TrackGraph() & {"track_graph_name": key["track_graph_name"]}
).fetch1()
# ------------------------------ Migrated Tables ------------------------------

track_graph = make_track_graph(
node_positions=track_graph_info["node_positions"],
edges=track_graph_info["edges"],
)
from spyglass.linearization.v0 import main as linV0 # noqa: E402

linear_position_df = get_linearized_position(
position=position,
track_graph=track_graph,
edge_spacing=track_graph_info["linear_edge_spacing"],
edge_order=track_graph_info["linear_edge_order"],
use_HMM=linearization_parameters["use_hmm"],
route_euclidean_distance_scaling=linearization_parameters[
"route_euclidean_distance_scaling"
],
sensor_std_dev=linearization_parameters["sensor_std_dev"],
diagonal_bias=linearization_parameters["diagonal_bias"],
)

linear_position_df["time"] = time

# Insert into analysis nwb file
nwb_analysis_file = AnalysisNwbfile()

key["linearized_position_object_id"] = nwb_analysis_file.add_nwb_object(
analysis_file_name=key["analysis_file_name"],
nwb_object=linear_position_df,
)

nwb_analysis_file.add(
nwb_file_name=key["nwb_file_name"],
analysis_file_name=key["analysis_file_name"],
)

self.insert1(key)

def fetch_nwb(self, *attrs, **kwargs):
return fetch_nwb(
self, (AnalysisNwbfile, "analysis_file_abs_path"), *attrs, **kwargs
)

def fetch1_dataframe(self):
return self.fetch_nwb()[0]["linearized_position"].set_index("time")
(
LinearizationParameters,
TrackGraph,
IntervalLinearizationSelection,
IntervalLinearizedPosition,
) = deprecated_factory(
[
("LinearizationParameters", linV0.LinearizationParameters),
("TrackGraph", linV0.TrackGraph),
(
"IntervalLinearizationSelection",
linV0.IntervalLinearizationSelection,
),
(
"IntervalLinearizedPosition",
linV0.IntervalLinearizedPosition,
),
],
old_module=__name__,
)


class NodePicker:
Expand Down
1 change: 1 addition & 0 deletions src/spyglass/linearization/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# from spyglass.linearization.merge import LinearizedOutput
Original file line number Diff line number Diff line change
@@ -1,27 +1,25 @@
import datajoint as dj

from spyglass.position_linearization.v1.linearization import ( # noqa F401
LinearizedPositionV1,
)
from spyglass.linearization.v1.main import LinearizedV1 # noqa F401

from ..utils.dj_merge_tables import _Merge

schema = dj.schema("position_linearization_merge")
schema = dj.schema("linearization_merge")


@schema
class LinearizedPositionOutput(_Merge):
class LinearizedOutput(_Merge):
definition = """
merge_id: uuid
---
source: varchar(32)
"""

class LinearizedPositionV1(dj.Part): # noqa: F811
class LinearizedV1(dj.Part): # noqa F811
definition = """
-> LinearizedPositionOutput
-> master
---
-> LinearizedPositionV1
-> LinearizedV1
"""

def fetch1_dataframe(self):
Expand Down
6 changes: 6 additions & 0 deletions src/spyglass/linearization/v0/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from .main import (
IntervalLinearizationSelection,
IntervalLinearizedPosition,
LinearizationParameters,
TrackGraph,
)
Loading

0 comments on commit 84da68d

Please sign in to comment.