Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add experiment type #702

Merged
merged 6 commits into from
Apr 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions newsfragments/702.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add `Experiment.get_type()` to replace `Experiment.is_still()/Experiment.is_sequence()`
2 changes: 2 additions & 0 deletions src/dxtbx/dxtbx_model_ext.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ from scitbx.array_family import shared as flex_shared
from scitbx.array_family.flex import FlexPlain

from dxtbx_model_ext import Probe # type: ignore
from dxtbx_model_ext import ExperimentType

# TypeVar for the set of Experiment models that can be joint-accepted
# - profile, imageset and scalingmodel are handled as 'object'
Expand Down Expand Up @@ -354,6 +355,7 @@ class Experiment:
def is_sequence(self) -> bool: ...
def is_still(self) -> bool: ...
def __contains__(self, obj: TExperimentModel) -> bool: ...
def get_type(self) -> ExperimentType: ...

class ExperimentList:
@overload
Expand Down
14 changes: 12 additions & 2 deletions src/dxtbx/model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
DetectorNode,
Experiment,
ExperimentList,
ExperimentType,
Goniometer,
GoniometerBase,
KappaDirection,
Expand Down Expand Up @@ -69,6 +70,7 @@
DetectorNode,
Experiment,
ExperimentList,
ExperimentType,
Goniometer,
GoniometerBase,
KappaDirection,
Expand Down Expand Up @@ -599,11 +601,19 @@

def all_stills(self):
"""Check if all the experiments are stills"""
return all(exp.is_still() for exp in self)
return all(exp.get_type() == ExperimentType.STILL for exp in self)

def all_sequences(self):
"""Check if all the experiments are from sequences"""
return all(exp.is_sequence() for exp in self)
return self.all_rotations()

Check warning on line 608 in src/dxtbx/model/__init__.py

View check run for this annotation

Codecov / codecov/patch

src/dxtbx/model/__init__.py#L608

Added line #L608 was not covered by tests

def all_rotations(self):
"""Check if all the experiments are stills"""
return all(exp.get_type() == ExperimentType.ROTATION for exp in self)

def all_tof(self):
"""Check if all the experiments are time-of-flight"""
return all(exp.get_type() == ExperimentType.TOF for exp in self)

def to_dict(self):
"""Serialize the experiment list to dictionary."""
Expand Down
6 changes: 6 additions & 0 deletions src/dxtbx/model/boost_python/experiment.cc
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,11 @@ namespace dxtbx { namespace model { namespace boost_python {
};

void export_experiment() {
enum_<ExperimentType>("ExperimentType")
.value("STILL", STILL)
.value("ROTATION", ROTATION)
.value("TOF", TOF);

class_<Experiment>("Experiment")
.def(init<std::shared_ptr<BeamBase>,
std::shared_ptr<Detector>,
Expand Down Expand Up @@ -118,6 +123,7 @@ namespace dxtbx { namespace model { namespace boost_python {
.def("is_sequence",
&Experiment::is_sequence,
"Check if this experiment represents swept rotation image(s)")
.def("get_type", &Experiment::get_type)
.def_pickle(ExperimentPickleSuite());
}

Expand Down
83 changes: 18 additions & 65 deletions src/dxtbx/model/experiment.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@

namespace dxtbx { namespace model {

enum ExperimentType { ROTATION = 1, STILL = 2, TOF = 3 };

/**
* A class to represent what's in an experiment.
*
Expand Down Expand Up @@ -128,9 +130,6 @@ namespace dxtbx { namespace model {
return profile_ == obj || imageset_ == obj || scaling_model_ == obj;
}

/**
* Compare this experiment with another
*/
bool operator==(const Experiment &other) const {
return imageset_ == other.imageset_ && beam_ == other.beam_
&& detector_ == other.detector_ && goniometer_ == other.goniometer_
Expand All @@ -139,18 +138,11 @@ namespace dxtbx { namespace model {
&& identifier_ == other.identifier_;
}

/**
* Check that the experiment is consistent
*/
bool is_consistent() const {
return true; // FIXME
}

/**
* Check if this experiment represents a still image
*/
bool is_still() const {
return !goniometer_ || !scan_ || scan_->is_still();
return get_type() == STILL;
}

/**
Expand All @@ -160,128 +152,89 @@ namespace dxtbx { namespace model {
return !is_still();
}

/**
* Set the beam model
*/
ExperimentType get_type() const {
if (scan_ && scan_->contains("time_of_flight")) {
return TOF;
}
if (!goniometer_ || !scan_ || scan_->is_still()) {
return STILL;
} else {
return ROTATION;
}
}

bool is_consistent() const {
return true; // FIXME
}

void set_beam(std::shared_ptr<BeamBase> beam) {
beam_ = beam;
}

/**
* Get the beam model
*/
std::shared_ptr<BeamBase> get_beam() const {
return beam_;
}

/**
* Get the detector model
*/
void set_detector(std::shared_ptr<Detector> detector) {
detector_ = detector;
}

/**
* Get the detector model
*/
std::shared_ptr<Detector> get_detector() const {
return detector_;
}

/**
* Get the goniometer model
*/
void set_goniometer(std::shared_ptr<Goniometer> goniometer) {
goniometer_ = goniometer;
}

/**
* Get the goniometer model
*/
std::shared_ptr<Goniometer> get_goniometer() const {
return goniometer_;
}

/**
* Get the scan model
*/
void set_scan(std::shared_ptr<Scan> scan) {
scan_ = scan;
}

/**
* Get the scan model
*/
std::shared_ptr<Scan> get_scan() const {
return scan_;
}

/**
* Get the crystal model
*/
void set_crystal(std::shared_ptr<CrystalBase> crystal) {
crystal_ = crystal;
}

/**
* Get the crystal model
*/
std::shared_ptr<CrystalBase> get_crystal() const {
return crystal_;
}

/**
* Get the profile model
*/
void set_profile(boost::python::object profile) {
profile_ = profile;
}

/**
* Get the profile model
*/
boost::python::object get_profile() const {
return profile_;
}

/**
* Get the imageset model
*/
void set_imageset(boost::python::object imageset) {
imageset_ = imageset;
}

/**
* Get the imageset model
*/
boost::python::object get_imageset() const {
return imageset_;
}

/**
* Set the scaling model
*/
void set_scaling_model(boost::python::object scaling_model) {
scaling_model_ = scaling_model;
}

/**
* Get the scaling model
*/
boost::python::object get_scaling_model() const {
return scaling_model_;
}

/**
* Set the identifier
*/
void set_identifier(std::string identifier) {
identifier_ = identifier;
}

/**
* Get the identifier
*/
std::string get_identifier() const {
return identifier_;
}
Expand Down
17 changes: 9 additions & 8 deletions tests/model/test_experiment_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
Detector,
Experiment,
ExperimentList,
ExperimentType,
Goniometer,
Scan,
ScanFactory,
Expand Down Expand Up @@ -773,26 +774,26 @@ def test_partial_missing_model_serialization():
check(elist, elist_)


def test_experiment_is_still():
def test_experiment_type():
experiment = Experiment()
assert experiment.is_still()
assert experiment.get_type() == ExperimentType.STILL
experiment.goniometer = Goniometer()
assert experiment.is_still()
assert experiment.get_type() == ExperimentType.STILL
experiment.scan = Scan()
assert experiment.is_still()
assert experiment.get_type() == ExperimentType.STILL
experiment.scan = Scan((1, 1000), (0, 0.05))
assert not experiment.is_still()
assert experiment.get_type() == ExperimentType.ROTATION
# Specifically test the bug from dxtbx#4 triggered by ending on 0°
experiment.scan = Scan((1, 1800), (-90, 0.05))
assert not experiment.is_still()
assert experiment.get_type() == ExperimentType.ROTATION
experiment.scan = ScanFactory.make_scan_from_properties(
(1, 10), properties={"time_of_flight": list(range(10))}
)
assert not experiment.is_still()
assert experiment.get_type() == ExperimentType.TOF
experiment.scan = ScanFactory.make_scan_from_properties(
(1, 10), properties={"other_property": list(range(10))}
)
assert experiment.is_still()
assert experiment.get_type() == ExperimentType.STILL


def check(el1, el2):
Expand Down
Loading