Skip to content

Commit

Permalink
ENH: add 2d kmeans agent
Browse files Browse the repository at this point in the history
  • Loading branch information
maffettone committed Dec 13, 2023
1 parent f9ad0b7 commit b7dac0b
Show file tree
Hide file tree
Showing 5 changed files with 89 additions and 13 deletions.
10 changes: 8 additions & 2 deletions bmm_agents/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ def __init__(
exp_bounds: str = "-200 -30 -10 25 12k",
exp_steps: str = "10 2 0.3 0.05k",
exp_times: str = "0.5 0.5 0.5 0.5",
variable_motor_names: List[str] = ["xafs_x"],
**kwargs,
):
self._filename = filename
Expand All @@ -51,6 +52,7 @@ def __init__(
self._exp_bounds = exp_bounds
self._exp_steps = exp_steps
self._exp_times = exp_times
self._variable_motor_names = variable_motor_names

_default_kwargs = self.get_beamline_objects()
_default_kwargs.update(kwargs)
Expand Down Expand Up @@ -199,11 +201,15 @@ def unpack_run(self, run):
idx_min = np.where(ordinate < self.roi[0])[0][-1] if len(np.where(ordinate < self.roi[0])[0]) else None
idx_max = np.where(ordinate > self.roi[1])[0][-1] if len(np.where(ordinate > self.roi[1])[0]) else None
y = y[idx_min:idx_max]
return run.baseline.data["xafs_x"][0], y
return np.array([run.baseline.data[key][0] for key in self._variable_motor_names]), y

def measurement_plan(self, relative_point: ArrayLike) -> Tuple[str, List, dict]:
"""Works from relative points"""
element_positions = self.element_origins + relative_point
if len(relative_point) == 2:
element_positions = self.element_origins + relative_point
else:
element_positions = self.element_origins
element_positions[0] += relative_point
args = [
self.sample_position_motors[0],
*element_positions[:, 0],
Expand Down
41 changes: 31 additions & 10 deletions bmm_agents/sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,10 @@
from numpy.typing import ArrayLike
from scipy.stats import rv_discrete
from sklearn.cluster import KMeans
from sklearn.linear_model import LinearRegression

from .base import BMMBaseAgent
from .utils import discretize, make_hashable
from .utils import discretize, make_hashable, make_wafer_grid_list

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -121,7 +122,16 @@ def _sample_uncertainty_proxy(self, batch_size=1):
"""
# Borrowing from Dan's jupyter fun
# from measurements, perform k-means
sorted_independents, sorted_observables = zip(*sorted(zip(self.independent_cache, self.observable_cache)))
try:
sorted_independents, sorted_observables = zip(
*sorted(zip(self.independent_cache, self.observable_cache))
)
except ValueError:
# Multidimensional case
sorted_independents, sorted_observables = zip(
*sorted(zip(self.independent_cache, self.observable_cache), key=lambda x: (x[0][0], x[0][1]))
)

sorted_independents = np.array(sorted_independents)
sorted_observables = np.array(sorted_observables)
self.model.fit(sorted_observables)
Expand All @@ -131,11 +141,19 @@ def _sample_uncertainty_proxy(self, batch_size=1):
distances = self.model.transform(sorted_observables)
# determine golf-score of each point (minimum value)
min_landscape = distances.min(axis=1)
# generate 'uncertainty weights' - as a polynomial fit of the golf-score for each point
_x = np.arange(*self.bounds, self.min_step_size)
uwx = polyval(_x, polyfit(sorted_independents, min_landscape, deg=5))
# Chose from the polynomial fit
return pick_from_distribution(_x, uwx, num_picks=batch_size), centers
if self.bounds.size == 2:
# Assume a 1d scan
# generate 'uncertainty weights' - as a polynomial fit of the golf-score for each point
_x = np.arange(*self.bounds, self.min_step_size)
uwx = polyval(_x, polyfit(sorted_independents, min_landscape, deg=5))
# Chose from the polynomial fit
return pick_from_distribution(_x, uwx, num_picks=batch_size), centers
else:
# assume a 2d scan, use a linear model to predict the uncertainty
grid = make_wafer_grid_list(*self.bounds.ravel(), step=self.min_step_size)
uncertainty_preds = LinearRegression().fit(sorted_independents, min_landscape).predict(grid)
top_indicies = np.argsort(uncertainty_preds)[-batch_size:]
return grid[top_indicies], centers

def ask(self, batch_size=1):
suggestions, centers = self._sample_uncertainty_proxy(batch_size)
Expand All @@ -144,11 +162,14 @@ def ask(self, batch_size=1):
suggestions = [suggestions]
# Keep non redundant suggestions and add to knowledge cache
for suggestion in suggestions:
if suggestion in self.knowledge_cache:
logger.info(f"Suggestion {suggestion} is ignored as already in the knowledge cache")
hashable_suggestion = make_hashable(discretize(suggestion, self.min_step_size))
if hashable_suggestion in self.knowledge_cache:
logger.info(
f"Suggestion {suggestion} is ignored as already in the knowledge cache: {hashable_suggestion}"
)
continue
else:
self.knowledge_cache.add(make_hashable(discretize(suggestion, self.min_step_size)))
self.knowledge_cache.add(hashable_suggestion)
kept_suggestions.append(suggestion)

base_doc = dict(
Expand Down
35 changes: 35 additions & 0 deletions bmm_agents/startup_scripts/historical_mmm4_uids.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
9c1f9ce6-6c35-42c9-ae0b-40620336596c
b02b4bfc-ef9f-4a8d-bfcd-852923061371
df21d22a-dc85-4566-932f-d66b20f714bd
7a5257a9-c5f6-4dab-908e-82ec74308817
818400f0-b9a7-489d-ba3b-9839d8e35700
3c8a6931-b465-46e2-860a-bca6a203ac04
4bbd010d-5310-4fad-ad4e-956517c04aff
7d7e5497-9fb4-4829-b526-425d42974012
8e4e5b73-2fad-41ca-b3ca-e1c1d6183052
797a514b-f673-4131-a236-f5250331f3dd
434b6f94-37ae-41d3-8d3e-8b4ab18d9711
ea441617-9794-46f0-8e6e-f704ebba8163
e8374ec8-2a80-48c4-a77f-bd3271677590
4a992e79-3f45-4c1c-8a99-a47f7b8d8af5
cb0629dc-a6ea-4581-abbc-bbde76aecb10
7fd0e59f-9b06-48f1-a17d-a9053032fe34
ef501a87-5e09-41aa-a72b-20004e00d510
dafdf68f-a064-4dd3-acf0-dd6506c0aca7
1ba7768a-bddb-48ac-9148-1162659c38d0
60d42219-ab88-4865-ae44-6684e538c322
6d1be8c4-2534-4e8b-a12e-82875eae3996
adb51916-d093-44d0-b86c-6397901d4eec
340e4116-2a30-4a4c-a1a4-04ca7c7657e0
91ce30b3-03cf-4557-b7a5-c97293dce1be
ec5023d6-a45d-4109-8d42-7cc0c74d72ed
2bd7ca7f-4ac4-4ed8-8eac-a5e8aa0aea89
bd09a4ee-3e36-4f07-b1f8-e9baace2617a
bd1a9f03-8117-4393-a49f-9ea2c28b53c1
95364f08-e085-41ad-9b23-6820a850c67f
5c2d9d83-89e2-481e-818d-73864b792ed6
58b676df-1f08-4554-8736-ac6ef1fe0422
0288139d-b373-4525-ad51-f919b4eb5d1a
adb8d6a8-6b7d-4d38-8c74-26fdc3268519
803bc7ef-60ba-4c43-962c-6db72b400f6d
4a9fc081-fd94-45c8-a2e4-e7cfd615d155
3 changes: 2 additions & 1 deletion bmm_agents/startup_scripts/mmm4_Pt_kmeans.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import numpy as np
from bluesky_adaptive.server import register_variable, shutdown_decorator, startup_decorator

from bmm_agents.sklearn import MultiElementActiveKmeansAgent
Expand All @@ -16,7 +17,7 @@
exp_bounds="-200 -30 -10 25 13k",
exp_steps="10 2 0.5 0.05k",
exp_times="1 1 1 1",
bounds=(-32, 32),
bounds=np.array([(-32, 32), (-32, 32)]),
ask_on_tell=False,
report_on_tell=True,
k_clusters=6,
Expand Down
13 changes: 13 additions & 0 deletions bmm_agents/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,19 @@ def make_hashable(x):
return float(x)


def make_wafer_grid_list(x_min, x_max, y_min, y_max, step):
    """Return all 2d grid points that lie inside the wafer's inscribed circle.

    Builds a regular grid over the rectangle ``[x_min, x_max) x [y_min, y_max)``
    with spacing ``step``, then keeps only the points strictly inside the circle
    centered at the rectangle's midpoint (NOT the origin) whose radius is the
    smaller half-extent of the two axes.

    Parameters
    ----------
    x_min, x_max : float
        Inclusive lower / exclusive upper bound along x (``np.arange`` semantics).
    y_min, y_max : float
        Inclusive lower / exclusive upper bound along y.
    step : float
        Grid spacing used along both axes.

    Returns
    -------
    np.ndarray
        Array of shape ``(n_points, 2)`` of ``(x, y)`` coordinates inside the circle.
    """
    x = np.arange(x_min, x_max, step)
    y = np.arange(y_min, y_max, step)
    xx, yy = np.meshgrid(x, y)
    # Circle is centered at the midpoint of the bounds, not at (0, 0).
    center = np.array([x_min + (x_max - x_min) / 2, y_min + (y_max - y_min) / 2])
    distance = np.sqrt((xx - center[0]) ** 2 + (yy - center[1]) ** 2)
    radius = min((x_max - x_min) / 2, (y_max - y_min) / 2)
    # Compute the boolean mask once (the original evaluated it twice).
    inside = distance < radius
    return np.column_stack([xx[inside], yy[inside]])


class Pandrosus:
"""A thin wrapper around basic XAS data processing for individual
data sets as implemented in Larch.
Expand Down

0 comments on commit b7dac0b

Please sign in to comment.