Skip to content

Commit

Permalink
Ensure correct collection numbers are propagated to detectors
Browse files Browse the repository at this point in the history
  • Loading branch information
callumforrester committed Aug 1, 2023
1 parent f14cca5 commit 842ac51
Show file tree
Hide file tree
Showing 3 changed files with 205 additions and 53 deletions.
89 changes: 61 additions & 28 deletions src/blueapi/plugins/data_writing.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import itertools
from abc import ABC, abstractmethod
from abc import ABC, abstractmethod, abstractproperty
from collections import deque
from pathlib import Path
from typing import (
Expand All @@ -18,6 +18,7 @@

import bluesky.plan_stubs as bps
import requests
from aiohttp import ClientSession
from bluesky.protocols import Movable
from bluesky.utils import Msg, make_decorator
from ophyd.areadetector.filestore_mixins import FileStoreBase
Expand All @@ -31,65 +32,97 @@


class DataCollectionProvider(ABC):
@abstractproperty
def current_data_collection(self) -> Optional[DataCollection]:
...

@abstractmethod
def get_next_data_collection(self, collection_group: str) -> DataCollection:
async def update(self) -> None:
...


class ServiceDataCollectionProvider(DataCollectionProvider):
def get_next_data_collection(self, collection_group: str) -> DataCollection:
reply = requests.post(f"http://localhost:8089/collection/{collection_group}")
result = DataCollectionSetupResult.parse_obj(reply.json())
_collection_group: str
_current_collection: Optional[DataCollection]

def __init__(self, collection_group: str) -> None:
self._collection_group = collection_group
self._current_collection = None

@property
def current_data_collection(self) -> Optional[DataCollection]:
return self._current_collection

async def update(self) -> None:
async with ClientSession() as session:
async with session.post(
f"http://localhost:8089/collection/{self._collection_group}"
) as response:
if response.status == 200:
json = await response.json()
result = DataCollectionSetupResult.parse_obj(json)
else:
raise Exception(response.status)
if result.directories_created:
return result.collection
result.collection
else:
raise Exception()


class InMemoryDataCollectionProvider(DataCollectionProvider):
_collection_group: str
_scan_number: itertools.count
_current_collection: Optional[DataCollection]

def __init__(self) -> None:
def __init__(self, collection_group: str) -> None:
self._collection_group = collection_group
self._scan_number = itertools.count()
self._current_collection = None

def get_next_data_collection(self, collection_group: str) -> DataCollection:
@property
def current_data_collection(self) -> Optional[DataCollection]:
return self._current_collection

async def update(self) -> None:
scan_number = next(self._scan_number)
return DataCollection(
self._current_collection = DataCollection(
collection_number=scan_number,
group=collection_group,
raw_data_files_root=Path(f"/tmp/{collection_group}"),
nexus_file_path=Path(f"/tmp{collection_group}.nxs"),
group=self._collection_group,
raw_data_files_root=Path(f"/tmp/{self._collection_group}"),
nexus_file_path=Path(f"/tmp{self._collection_group}.nxs"),
)


def data_writing_wrapper(
plan: MsgGenerator,
collection_group: str,
provider: Optional[DataCollectionProvider] = None,
provider: DataCollectionProvider,
) -> MsgGenerator:
if provider is None:
provider = InMemoryDataCollectionProvider()

scan_number = itertools.count()
next_scan_number = None
stage_stack: Deque = deque()
# scan_number_stack: Deque = deque()
for message in plan:
if message.command == "stage":
if not stage_stack:
yield from bps.wait_for([provider.update])
if provider.current_data_collection is None:
raise Exception("There is no active data collection")
stage_stack.append(message.obj)
all_devices = walk_devices([message.obj])
configure_data_writing(
all_devices,
provider.current_data_collection,
)
elif stage_stack:
next_scan_number = next(scan_number)
root_devices = []
while stage_stack:
root_devices.append(stage_stack.pop())
all_devices = walk_devices(root_devices)
collection = provider.get_next_data_collection(collection_group)
configure_data_writing(all_devices, collection)
stage_stack.pop()

if message.command == "open_run":
if next_scan_number is None:
next_scan_number = next(scan_number)
message.kwargs[DATA_COLLECTION_NUMBER] = next_scan_number
if provider.current_data_collection is None:
yield from bps.wait_for([provider.update])
if provider.current_data_collection is None:
raise Exception("There is no active data collection")
message.kwargs[
DATA_COLLECTION_NUMBER
] = provider.current_data_collection.collection_number
yield message


Expand Down
31 changes: 31 additions & 0 deletions tests/plugins/file_writing_detector.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import itertools
import uuid
from pathlib import Path
from typing import List, Optional, Tuple

import h5py as h5
import numpy as np
from ophyd import Component, Device, Signal, SignalRO
from ophyd.sim import EnumSignal, SynGauss, SynSignal, SynSignalRO

from blueapi.plugins.data_writing import DataCollectionProvider


class FakeFileWritingDetector(Device):
image_count: Signal = Component(Signal, value=0, kind="hinted")
collection_number: Signal = Component(Signal, value=0, kind="config")

_provider: DataCollectionProvider

def __init__(self, name: str, provider: DataCollectionProvider, **kwargs):
super().__init__(name=name, **kwargs)
self.stage_sigs[self.image_count] = 0
self._provider = provider

def trigger(self, *args, **kwargs):
return self.image_count.set(self.image_count.get() + 1)

def stage(self) -> List[object]:
collection = self._provider.current_data_collection
self.stage_sigs[self.collection_number] = collection.collection_number
return super().stage()
Loading

0 comments on commit 842ac51

Please sign in to comment.