argmin_observer_spectator/observer.rs
1// Copyright 2018-2024 argmin developers
2//
3// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4// http://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
5// http://opensource.org/licenses/MIT>, at your option. This file may not be
6// copied, modified, or distributed except according to those terms.
7
8use std::{collections::HashSet, thread::JoinHandle};
9
10use anyhow::Error;
11use argmin::core::{
12 observers::Observe, ArgminFloat, State, TerminationReason, TerminationStatus, KV,
13};
14use spectator::{Message, DEFAULT_PORT};
15use time::Duration;
16use uuid::Uuid;
17
18use crate::sender::sender;
19
/// Host used when [`SpectatorBuilder::with_host`] is not called.
const DEFAULT_HOST: &str = "127.0.0.1";

/// Builder for the Spectator observer
///
/// # Example
///
/// ```
/// use argmin_observer_spectator::SpectatorBuilder;
///
/// let spectator = SpectatorBuilder::new()
///     // Optional: Name the optimization run
///     // Default: random UUID.
///     .with_name("optimization_run_1")
///     // Optional, defaults to 127.0.0.1
///     .with_host("127.0.0.1")
///     // Optional, defaults to 5498
///     .with_port(5498)
///     // Optional, defaults to 10000
///     .with_channel_capacity(10_000)
///     // Choose which metrics should automatically be selected.
///     // If omitted, all metrics will be selected.
///     .select(&["cost", "best_cost"])
///     // Build Spectator observer
///     .build();
/// ```
#[derive(Debug)]
pub struct SpectatorBuilder {
    /// Name the optimization run is identified with in Spectator.
    name: String,
    /// Metrics selected by default in the Spectator GUI; empty means "all".
    selected: HashSet<String>,
    /// Capacity of the channel that queues messages for the sender thread.
    capacity: usize,
    /// Host Spectator is expected to run on.
    host: String,
    /// Port Spectator is expected to listen on.
    port: u16,
}
50
51impl Default for SpectatorBuilder {
52 fn default() -> Self {
53 Self::new()
54 }
55}
56
57impl SpectatorBuilder {
58 /// Creates a new `SpectatorBuilder`
59 pub fn new() -> Self {
60 SpectatorBuilder {
61 name: Uuid::new_v4().to_string(),
62 selected: HashSet::new(),
63 capacity: 10_000,
64 host: DEFAULT_HOST.to_string(),
65 port: DEFAULT_PORT,
66 }
67 }
68
69 /// Set a name the optimization run will be identified with
70 ///
71 /// Defaults to a random UUID.
72 ///
73 /// # Example
74 ///
75 /// ```
76 /// # use argmin_observer_spectator::SpectatorBuilder;
77 /// let builder = SpectatorBuilder::new().with_name("optimization_run_1");
78 /// # assert_eq!(builder.name().clone(), "optimization_run_1".to_string());
79 /// ```
80 pub fn with_name<T: AsRef<str>>(mut self, name: T) -> Self {
81 self.name = name.as_ref().to_string();
82 self
83 }
84
85 /// Set the host argmin spectator is running on.
86 ///
87 /// Defaults to 127.0.0.1.
88 ///
89 /// # Example
90 ///
91 /// ```
92 /// # use argmin_observer_spectator::SpectatorBuilder;
93 /// let builder = SpectatorBuilder::new().with_host("192.168.0.1");
94 /// # assert_eq!(builder.host().clone(), "192.168.0.1".to_string());
95 /// ```
96 pub fn with_host<T: AsRef<str>>(mut self, host: T) -> Self {
97 self.host = host.as_ref().to_string();
98 self
99 }
100
101 /// Set the port Spectator is running on.
102 ///
103 /// Defaults to 5498.
104 ///
105 /// # Example
106 ///
107 /// ```
108 /// # use argmin_observer_spectator::SpectatorBuilder;
109 /// let builder = SpectatorBuilder::new().with_port(1234);
110 /// # assert_eq!(builder.port(), 1234);
111 /// ```
112 pub fn with_port(mut self, port: u16) -> Self {
113 self.port = port;
114 self
115 }
116
117 /// Set the channel capacity
118 ///
119 /// A channel is used to queue messages for sending to Spectator. If the channel
120 /// capacity is reached backpressure will be applied, effectively blocking the optimization.
121 /// Defaults to 10000. Decrease this value in case memory consumption is too high and increase
122 /// the value in case blocking causes negative effects.
123 ///
124 /// # Example
125 ///
126 /// ```
127 /// # use argmin_observer_spectator::SpectatorBuilder;
128 /// let builder = SpectatorBuilder::new().with_channel_capacity(1000);
129 /// # assert_eq!(builder.channel_capacity(), 1000);
130 /// ```
131 pub fn with_channel_capacity(mut self, capacity: usize) -> Self {
132 self.capacity = capacity;
133 self
134 }
135
136 /// Define which metrics will be selected in Spectator by default
137 ///
138 /// If none are set, all metrics will be selected and shown. Providing zero or more metrics
139 /// via `select` disables all apart from the provided ones. Note that independent of this
140 /// setting, all data will be sent, and metrics can be selected and deselected via the
141 /// Spectator GUI.
142 ///
143 /// # Example
144 ///
145 /// ```
146 /// # use argmin_observer_spectator::SpectatorBuilder;
147 /// # use std::collections::HashSet;
148 /// let builder = SpectatorBuilder::new().select(&["cost", "best_cost"]);
149 /// # assert_eq!(builder.selected(), &HashSet::from(["cost".to_string(), "best_cost".to_string()]));
150 /// ```
151 pub fn select<T: AsRef<str>>(mut self, metrics: &[T]) -> Self {
152 self.selected = metrics.iter().map(|s| s.as_ref().to_string()).collect();
153 self
154 }
155
156 /// Returns the name of the optimization run
157 ///
158 /// # Example
159 ///
160 /// ```
161 /// # use argmin_observer_spectator::SpectatorBuilder;
162 /// # let builder = SpectatorBuilder::new().with_name("test");
163 /// let name = builder.name();
164 /// # assert_eq!(name, &"test".to_string());
165 /// ```
166 pub fn name(&self) -> &String {
167 &self.name
168 }
169
170 /// Returns the host this observer will connect to
171 ///
172 /// # Example
173 ///
174 /// ```
175 /// # use argmin_observer_spectator::SpectatorBuilder;
176 /// # let builder = SpectatorBuilder::new();
177 /// let host = builder.host();
178 /// # assert_eq!(host, &"127.0.0.1".to_string());
179 /// ```
180 pub fn host(&self) -> &String {
181 &self.host
182 }
183
184 /// Returns the port this observer will connect to
185 ///
186 /// # Example
187 ///
188 /// ```
189 /// # use argmin_observer_spectator::SpectatorBuilder;
190 /// # let builder = SpectatorBuilder::new();
191 /// let port = builder.port();
192 /// # assert_eq!(port, 5498);
193 /// ```
194 pub fn port(&self) -> u16 {
195 self.port
196 }
197
198 /// Returns the channel capacity
199 ///
200 /// # Example
201 ///
202 /// ```
203 /// # use argmin_observer_spectator::SpectatorBuilder;
204 /// # let builder = SpectatorBuilder::new();
205 /// let capacity = builder.channel_capacity();
206 /// # assert_eq!(capacity, 10000);
207 /// ```
208 pub fn channel_capacity(&self) -> usize {
209 self.capacity
210 }
211
212 /// Returns the selected metrics
213 ///
214 /// # Example
215 ///
216 /// ```
217 /// # use argmin_observer_spectator::SpectatorBuilder;
218 /// # use std::collections::HashSet;
219 /// # let builder = SpectatorBuilder::new().select(&["cost", "best_cost"]);
220 /// let selected = builder.selected();
221 /// # assert_eq!(selected, &HashSet::from(["cost".to_string(), "best_cost".to_string()]));
222 /// ```
223 pub fn selected(&self) -> &HashSet<String> {
224 &self.selected
225 }
226
227 /// Build a Spectator instance from the builder
228 ///
229 /// This initiates the connection to the Spectator instance.
230 ///
231 /// # Example
232 ///
233 /// ```
234 /// # use argmin_observer_spectator::SpectatorBuilder;
235 /// let spectator = SpectatorBuilder::new().build();
236 /// ```
237 pub fn build(self) -> Spectator {
238 let (tx, rx) = tokio::sync::mpsc::channel(self.capacity);
239 let thread_handle = std::thread::spawn(move || sender(rx, self.host, self.port));
240
241 Spectator {
242 tx,
243 name: self.name,
244 sending: true,
245 selected: self.selected,
246 thread_handle: Some(thread_handle),
247 }
248 }
249}
250
251/// Observer which sends data to Spectator
252// No #[derive(Clone)] on purpose: A clone will only overwrite information already present in the
253// Spectator since the name cannot be changed.
254pub struct Spectator {
255 tx: tokio::sync::mpsc::Sender<Message>,
256 name: String,
257 sending: bool,
258 selected: HashSet<String>,
259 thread_handle: Option<JoinHandle<Result<(), Error>>>,
260}
261
262impl Spectator {
263 /// Places a `Message` on the sending queue
264 fn send_msg(&mut self, message: Message) {
265 if self.sending {
266 if let Err(e) = self.tx.blocking_send(message) {
267 eprintln!("Can't send to Spectator: {e}. Will stop trying.");
268 self.sending = false;
269 }
270 }
271 }
272
273 /// Returns the name of the Spectator instance
274 ///
275 /// # Example
276 ///
277 /// ```
278 /// # use argmin_observer_spectator::SpectatorBuilder;
279 /// # let spectator = SpectatorBuilder::new().with_name("flup").build();
280 /// let name = spectator.name();
281 /// # assert_eq!(name, &"flup".to_string());
282 /// ```
283 pub fn name(&self) -> &String {
284 &self.name
285 }
286}
287
288impl<I> Observe<I> for Spectator
289where
290 I: State,
291 I::Param: IntoIterator<Item = I::Float> + Clone,
292 I::Float: ArgminFloat,
293 f64: From<I::Float>,
294{
295 /// Log basic information about the optimization after initialization.
296 fn observe_init(&mut self, name: &str, state: &I, kv: &KV) -> Result<(), Error> {
297 let init_param = state.get_param().map(|init_param| {
298 init_param
299 .clone()
300 .into_iter()
301 .map(f64::from)
302 .collect::<Vec<_>>()
303 });
304
305 let message = Message::NewRun {
306 name: self.name.clone(),
307 solver: name.to_string(),
308 max_iter: state.get_max_iters(),
309 target_cost: f64::from(state.get_target_cost()),
310 init_param,
311 settings: kv.clone(),
312 selected: self.selected.clone(),
313 };
314
315 self.send_msg(message);
316
317 Ok(())
318 }
319
320 /// Logs information about the progress of the optimization after every iteration.
321 fn observe_iter(&mut self, state: &I, kv: &KV) -> Result<(), Error> {
322 let mut kv = kv.clone();
323 let iter = state.get_iter();
324 kv.insert("best_cost", state.get_best_cost().into());
325 kv.insert("cost", state.get_cost().into());
326 kv.insert("iter", iter.into());
327
328 let message_samples = Message::Samples {
329 name: self.name.clone(),
330 iter,
331 time: Duration::try_from(
332 state
333 .get_time()
334 .unwrap_or(std::time::Duration::from_secs(0)),
335 )?,
336 termination_status: state.get_termination_status().clone(),
337 kv,
338 };
339
340 self.send_msg(message_samples);
341
342 let message_func_counts = Message::FuncCounts {
343 name: self.name.clone(),
344 iter,
345 kv: state.get_func_counts().clone(),
346 };
347
348 self.send_msg(message_func_counts);
349
350 if let Some(param) = state.get_param() {
351 let param = param.clone().into_iter().map(f64::from).collect::<Vec<_>>();
352
353 let message_param = Message::Param {
354 name: self.name.clone(),
355 iter,
356 param,
357 };
358
359 self.send_msg(message_param);
360 }
361
362 if state.is_best() {
363 if let Some(best_param) = state.get_best_param() {
364 let best_param = best_param
365 .clone()
366 .into_iter()
367 .map(f64::from)
368 .collect::<Vec<_>>();
369
370 let message_best_param = Message::BestParam {
371 name: self.name.clone(),
372 iter,
373 param: best_param,
374 };
375
376 self.send_msg(message_best_param);
377 }
378 }
379
380 Ok(())
381 }
382
383 /// Forwards termination reason to spectator
384 fn observe_final(&mut self, state: &I) -> Result<(), Error> {
385 let message = Message::Termination {
386 name: self.name.clone(),
387 termination_status: state.get_termination_status().clone(),
388 };
389 self.send_msg(message);
390 Ok(())
391 }
392}
393
394impl Drop for Spectator {
395 fn drop(&mut self) {
396 // This allows the observer to finish sending message to spectator, while making sure that
397 // it doesn't get stuck when the solver terminates unexpectedly.
398 let message = Message::Termination {
399 name: self.name.clone(),
400 termination_status: TerminationStatus::Terminated(TerminationReason::SolverExit(
401 "Aborted".into(),
402 )),
403 };
404 self.send_msg(message);
405 self.thread_handle
406 .take()
407 .map(JoinHandle::join)
408 .unwrap()
409 .unwrap()
410 .unwrap();
411 }
412}