argmin_observer_spectator/
observer.rs

1// Copyright 2018-2024 argmin developers
2//
3// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4// http://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
5// http://opensource.org/licenses/MIT>, at your option. This file may not be
6// copied, modified, or distributed except according to those terms.
7
8use std::{collections::HashSet, thread::JoinHandle};
9
10use anyhow::Error;
11use argmin::core::{
12    observers::Observe, ArgminFloat, State, TerminationReason, TerminationStatus, KV,
13};
14use spectator::{Message, DEFAULT_PORT};
15use time::Duration;
16use uuid::Uuid;
17
18use crate::sender::sender;
19
/// Host the observer connects to unless [`SpectatorBuilder::with_host`] overrides it.
const DEFAULT_HOST: &str = "127.0.0.1";
21
/// Builder for the Spectator observer
///
/// # Example
///
/// ```
/// use argmin_observer_spectator::SpectatorBuilder;
///
/// let spectator = SpectatorBuilder::new()
///     // Optional: Name the optimization run
///     // Default: random UUID.
///     .with_name("optimization_run_1")
///     // Optional, defaults to 127.0.0.1
///     .with_host("127.0.0.1")
///     // Optional, defaults to 5498
///     .with_port(5498)
///     // Optional: Choose which metrics should automatically be selected.
///     // If omitted, all metrics will be selected.
///     .select(&["cost", "best_cost"])
///     // Build Spectator observer
///     .build();
/// ```
#[derive(Debug)]
pub struct SpectatorBuilder {
    /// Name the optimization run is identified with in Spectator.
    name: String,
    /// Metrics selected by default; an empty set means all metrics are selected.
    selected: HashSet<String>,
    /// Capacity of the message queue between the observer and the sender thread.
    capacity: usize,
    /// Host Spectator is running on.
    host: String,
    /// Port Spectator is listening on.
    port: u16,
}
50
51impl Default for SpectatorBuilder {
52    fn default() -> Self {
53        Self::new()
54    }
55}
56
57impl SpectatorBuilder {
58    /// Creates a new `SpectatorBuilder`
59    pub fn new() -> Self {
60        SpectatorBuilder {
61            name: Uuid::new_v4().to_string(),
62            selected: HashSet::new(),
63            capacity: 10_000,
64            host: DEFAULT_HOST.to_string(),
65            port: DEFAULT_PORT,
66        }
67    }
68
69    /// Set a name the optimization run will be identified with
70    ///
71    /// Defaults to a random UUID.
72    ///
73    /// # Example
74    ///
75    /// ```
76    /// # use argmin_observer_spectator::SpectatorBuilder;
77    /// let builder = SpectatorBuilder::new().with_name("optimization_run_1");
78    /// # assert_eq!(builder.name().clone(), "optimization_run_1".to_string());
79    /// ```
80    pub fn with_name<T: AsRef<str>>(mut self, name: T) -> Self {
81        self.name = name.as_ref().to_string();
82        self
83    }
84
85    /// Set the host argmin spectator is running on.
86    ///
87    /// Defaults to 127.0.0.1.
88    ///
89    /// # Example
90    ///
91    /// ```
92    /// # use argmin_observer_spectator::SpectatorBuilder;
93    /// let builder = SpectatorBuilder::new().with_host("192.168.0.1");
94    /// # assert_eq!(builder.host().clone(), "192.168.0.1".to_string());
95    /// ```
96    pub fn with_host<T: AsRef<str>>(mut self, host: T) -> Self {
97        self.host = host.as_ref().to_string();
98        self
99    }
100
101    /// Set the port Spectator is running on.
102    ///
103    /// Defaults to 5498.
104    ///
105    /// # Example
106    ///
107    /// ```
108    /// # use argmin_observer_spectator::SpectatorBuilder;
109    /// let builder = SpectatorBuilder::new().with_port(1234);
110    /// # assert_eq!(builder.port(), 1234);
111    /// ```
112    pub fn with_port(mut self, port: u16) -> Self {
113        self.port = port;
114        self
115    }
116
117    /// Set the channel capacity
118    ///
119    /// A channel is used to queue messages for sending to Spectator. If the channel
120    /// capacity is reached backpressure will be applied, effectively blocking the optimization.
121    /// Defaults to 10000. Decrease this value in case memory consumption is too high and increase
122    /// the value in case blocking causes negative effects.
123    ///
124    /// # Example
125    ///
126    /// ```
127    /// # use argmin_observer_spectator::SpectatorBuilder;
128    /// let builder = SpectatorBuilder::new().with_channel_capacity(1000);
129    /// # assert_eq!(builder.channel_capacity(), 1000);
130    /// ```
131    pub fn with_channel_capacity(mut self, capacity: usize) -> Self {
132        self.capacity = capacity;
133        self
134    }
135
136    /// Define which metrics will be selected in Spectator by default
137    ///
138    /// If none are set, all metrics will be selected and shown. Providing zero or more metrics
139    /// via `select` disables all apart from the provided ones. Note that independent of this
140    /// setting, all data will be sent, and metrics can be selected and deselected via the
141    /// Spectator GUI.
142    ///
143    /// # Example
144    ///
145    /// ```
146    /// # use argmin_observer_spectator::SpectatorBuilder;
147    /// # use std::collections::HashSet;
148    /// let builder = SpectatorBuilder::new().select(&["cost", "best_cost"]);
149    /// # assert_eq!(builder.selected(), &HashSet::from(["cost".to_string(), "best_cost".to_string()]));
150    /// ```
151    pub fn select<T: AsRef<str>>(mut self, metrics: &[T]) -> Self {
152        self.selected = metrics.iter().map(|s| s.as_ref().to_string()).collect();
153        self
154    }
155
156    /// Returns the name of the optimization run
157    ///
158    /// # Example
159    ///
160    /// ```
161    /// # use argmin_observer_spectator::SpectatorBuilder;
162    /// # let builder = SpectatorBuilder::new().with_name("test");
163    /// let name = builder.name();
164    /// # assert_eq!(name, &"test".to_string());
165    /// ```
166    pub fn name(&self) -> &String {
167        &self.name
168    }
169
170    /// Returns the host this observer will connect to
171    ///
172    /// # Example
173    ///
174    /// ```
175    /// # use argmin_observer_spectator::SpectatorBuilder;
176    /// # let builder = SpectatorBuilder::new();
177    /// let host = builder.host();
178    /// # assert_eq!(host, &"127.0.0.1".to_string());
179    /// ```
180    pub fn host(&self) -> &String {
181        &self.host
182    }
183
184    /// Returns the port this observer will connect to
185    ///
186    /// # Example
187    ///
188    /// ```
189    /// # use argmin_observer_spectator::SpectatorBuilder;
190    /// # let builder = SpectatorBuilder::new();
191    /// let port = builder.port();
192    /// # assert_eq!(port, 5498);
193    /// ```
194    pub fn port(&self) -> u16 {
195        self.port
196    }
197
198    /// Returns the channel capacity
199    ///
200    /// # Example
201    ///
202    /// ```
203    /// # use argmin_observer_spectator::SpectatorBuilder;
204    /// # let builder = SpectatorBuilder::new();
205    /// let capacity = builder.channel_capacity();
206    /// # assert_eq!(capacity, 10000);
207    /// ```
208    pub fn channel_capacity(&self) -> usize {
209        self.capacity
210    }
211
212    /// Returns the selected metrics
213    ///
214    /// # Example
215    ///
216    /// ```
217    /// # use argmin_observer_spectator::SpectatorBuilder;
218    /// # use std::collections::HashSet;
219    /// # let builder = SpectatorBuilder::new().select(&["cost", "best_cost"]);
220    /// let selected = builder.selected();
221    /// # assert_eq!(selected, &HashSet::from(["cost".to_string(), "best_cost".to_string()]));
222    /// ```
223    pub fn selected(&self) -> &HashSet<String> {
224        &self.selected
225    }
226
227    /// Build a Spectator instance from the builder
228    ///
229    /// This initiates the connection to the Spectator instance.
230    ///
231    /// # Example
232    ///
233    /// ```
234    /// # use argmin_observer_spectator::SpectatorBuilder;
235    /// let spectator = SpectatorBuilder::new().build();
236    /// ```
237    pub fn build(self) -> Spectator {
238        let (tx, rx) = tokio::sync::mpsc::channel(self.capacity);
239        let thread_handle = std::thread::spawn(move || sender(rx, self.host, self.port));
240
241        Spectator {
242            tx,
243            name: self.name,
244            sending: true,
245            selected: self.selected,
246            thread_handle: Some(thread_handle),
247        }
248    }
249}
250
251/// Observer which sends data to Spectator
252// No #[derive(Clone)] on purpose: A clone will only overwrite information already present in the
253// Spectator since the name cannot be changed.
254pub struct Spectator {
255    tx: tokio::sync::mpsc::Sender<Message>,
256    name: String,
257    sending: bool,
258    selected: HashSet<String>,
259    thread_handle: Option<JoinHandle<Result<(), Error>>>,
260}
261
262impl Spectator {
263    /// Places a `Message` on the sending queue
264    fn send_msg(&mut self, message: Message) {
265        if self.sending {
266            if let Err(e) = self.tx.blocking_send(message) {
267                eprintln!("Can't send to Spectator: {e}. Will stop trying.");
268                self.sending = false;
269            }
270        }
271    }
272
273    /// Returns the name of the Spectator instance
274    ///
275    /// # Example
276    ///
277    /// ```
278    /// # use argmin_observer_spectator::SpectatorBuilder;
279    /// # let spectator = SpectatorBuilder::new().with_name("flup").build();
280    /// let name = spectator.name();
281    /// # assert_eq!(name, &"flup".to_string());
282    /// ```
283    pub fn name(&self) -> &String {
284        &self.name
285    }
286}
287
288impl<I> Observe<I> for Spectator
289where
290    I: State,
291    I::Param: IntoIterator<Item = I::Float> + Clone,
292    I::Float: ArgminFloat,
293    f64: From<I::Float>,
294{
295    /// Log basic information about the optimization after initialization.
296    fn observe_init(&mut self, name: &str, state: &I, kv: &KV) -> Result<(), Error> {
297        let init_param = state.get_param().map(|init_param| {
298            init_param
299                .clone()
300                .into_iter()
301                .map(f64::from)
302                .collect::<Vec<_>>()
303        });
304
305        let message = Message::NewRun {
306            name: self.name.clone(),
307            solver: name.to_string(),
308            max_iter: state.get_max_iters(),
309            target_cost: f64::from(state.get_target_cost()),
310            init_param,
311            settings: kv.clone(),
312            selected: self.selected.clone(),
313        };
314
315        self.send_msg(message);
316
317        Ok(())
318    }
319
320    /// Logs information about the progress of the optimization after every iteration.
321    fn observe_iter(&mut self, state: &I, kv: &KV) -> Result<(), Error> {
322        let mut kv = kv.clone();
323        let iter = state.get_iter();
324        kv.insert("best_cost", state.get_best_cost().into());
325        kv.insert("cost", state.get_cost().into());
326        kv.insert("iter", iter.into());
327
328        let message_samples = Message::Samples {
329            name: self.name.clone(),
330            iter,
331            time: Duration::try_from(
332                state
333                    .get_time()
334                    .unwrap_or(std::time::Duration::from_secs(0)),
335            )?,
336            termination_status: state.get_termination_status().clone(),
337            kv,
338        };
339
340        self.send_msg(message_samples);
341
342        let message_func_counts = Message::FuncCounts {
343            name: self.name.clone(),
344            iter,
345            kv: state.get_func_counts().clone(),
346        };
347
348        self.send_msg(message_func_counts);
349
350        if let Some(param) = state.get_param() {
351            let param = param.clone().into_iter().map(f64::from).collect::<Vec<_>>();
352
353            let message_param = Message::Param {
354                name: self.name.clone(),
355                iter,
356                param,
357            };
358
359            self.send_msg(message_param);
360        }
361
362        if state.is_best() {
363            if let Some(best_param) = state.get_best_param() {
364                let best_param = best_param
365                    .clone()
366                    .into_iter()
367                    .map(f64::from)
368                    .collect::<Vec<_>>();
369
370                let message_best_param = Message::BestParam {
371                    name: self.name.clone(),
372                    iter,
373                    param: best_param,
374                };
375
376                self.send_msg(message_best_param);
377            }
378        }
379
380        Ok(())
381    }
382
383    /// Forwards termination reason to spectator
384    fn observe_final(&mut self, state: &I) -> Result<(), Error> {
385        let message = Message::Termination {
386            name: self.name.clone(),
387            termination_status: state.get_termination_status().clone(),
388        };
389        self.send_msg(message);
390        Ok(())
391    }
392}
393
394impl Drop for Spectator {
395    fn drop(&mut self) {
396        // This allows the observer to finish sending message to spectator, while making sure that
397        // it doesn't get stuck when the solver terminates unexpectedly.
398        let message = Message::Termination {
399            name: self.name.clone(),
400            termination_status: TerminationStatus::Terminated(TerminationReason::SolverExit(
401                "Aborted".into(),
402            )),
403        };
404        self.send_msg(message);
405        self.thread_handle
406            .take()
407            .map(JoinHandle::join)
408            .unwrap()
409            .unwrap()
410            .unwrap();
411    }
412}