clu_middleware_tron/
rabbitmq.rs

/*
 *  @Author: José Sánchez-Gallego (gallegoj@uw.edu)
 *  @Date: 2025-11-22
 *  @Filename: rabbitmq.rs
 *  @License: BSD 3-clause (http://www.opensource.org/licenses/BSD-3-Clause)
 */
7
8use std::sync::Arc;
9
10use async_channel::{Receiver, Sender};
11use bytes::BytesMut;
12use futures_lite::StreamExt;
13use lapin::{
14    Connection, ConnectionProperties, message::Delivery, options::*, types::AMQPValue,
15    types::FieldTable,
16};
17use tokio::sync::Mutex;
18
19use crate::{parser::Reply, tool::CommandID};
20
/// Settings that describe how an actor connects to RabbitMQ.
///
/// Obtain one with [`RabbitMQConfig::default`] and override individual fields as needed.
#[derive(Clone)]
pub struct RabbitMQConfig {
    /// Name of the actor; used to derive its command queue, binding key and reply sender.
    pub actor_name: String,
    /// AMQP URI of the RabbitMQ server.
    pub uri: String,
    /// Exchange that commands are consumed from and replies are published to.
    pub exchange: String,
    /// When `true`, replies received over TCP are forwarded to RabbitMQ.
    pub monitor_tcp_replies: bool,
}

impl RabbitMQConfig {
    /// Broker URI used by [`RabbitMQConfig::default`].
    const DEFAULT_URI: &'static str = "amqp://localhost:5672";
    /// Exchange name used by [`RabbitMQConfig::default`].
    const DEFAULT_EXCHANGE: &'static str = "sdss_exchange";

    /// Builds the default configuration for `actor_name`.
    ///
    /// It uses a broker at `amqp://localhost:5672` and the `sdss_exchange` exchange, and it
    /// enables TCP reply monitoring.
    pub fn default(actor_name: String) -> Self {
        let uri = Self::DEFAULT_URI.to_owned();
        let exchange = Self::DEFAULT_EXCHANGE.to_owned();
        Self {
            actor_name,
            uri,
            exchange,
            monitor_tcp_replies: true,
        }
    }
}
45
46/// Handles replies received from the RabbitMQ receiver channel and publishes them to RabbitMQ.
47///
48/// # Arguments
49/// * `channel` - The Lapin channel to use for publishing messages.
50/// * `config` - The RabbitMQ configuration.
51/// * `rabbitmq_receiver` - The receiver channel from which replies are received.
52/// * `command_id_pool_mutex` - A mutex-protected CommandID pool for tracking command UUIDs.
53///
54pub async fn handle_replies(
55    channel: lapin::Channel,
56    config: RabbitMQConfig,
57    rabbitmq_receiver: Receiver<Reply>,
58    command_id_pool_mutex: Arc<Mutex<CommandID>>,
59) -> Result<(), String> {
60    loop {
61        // Loop to receive replies from the queue. These messages are put here by the TCP client handler.
62        match rabbitmq_receiver.recv().await {
63            Ok(reply) => {
64                // Lock the command ID pool to access UUIDs and commanders.
65                let mut command_id_pool = command_id_pool_mutex.lock().await;
66
67                log::debug!("Publishing reply from actor to RabbitMQ");
68
69                // Actor sending the message, i.e., us.
70                let sender = config.actor_name.clone();
71
72                // The numeric command ID from the reply.
73                let command_id = reply.command_id;
74
75                // Get the UUID associated with the command ID from the pool.
76                // If the command ID is not found or this is a broadcast (command_id=0)
77                // use a default UUID.
78                let mut uuid = String::from("00000000-0000-0000-0000-000000000000");
79                let mut is_broadcast = false;
80                if let Some(found_uuid) = command_id_pool.get_uuid(command_id as u16) {
81                    uuid = found_uuid.clone();
82                } else if command_id == 0 {
83                    log::debug!("Command ID is 0, using broadcast UUID for reply publishing.");
84                    is_broadcast = true;
85                } else {
86                    log::warn!("No UUID found for command ID {}.", command_id);
87                    is_broadcast = true;
88                };
89
90                // Get the commander ID associated with the UUID from the pool. This is a '.'
91                // concatenated string of the client chain that sent the command and the consumer
92                // (i.e., us). If not found, use "broadcast".
93                let commander_id = if let Some(commander) = command_id_pool.get_commander(&uuid) {
94                    commander.clone()
95                } else {
96                    String::from("broadcast")
97                };
98
99                // Prepare the header fields.
100                let mut headers = FieldTable::default();
101                headers.insert(
102                    "message_code".into(),
103                    AMQPValue::LongString(reply.code.to_string().into()),
104                );
105                headers.insert(
106                    "command_id".into(),
107                    AMQPValue::LongString(uuid.clone().into()),
108                );
109                headers.insert(
110                    "commander_id".into(),
111                    AMQPValue::LongString(commander_id.clone().into()),
112                );
113                headers.insert("sender".into(), AMQPValue::LongString(sender.into()));
114                headers.insert("internal".into(), AMQPValue::Boolean(false));
115
116                // Prepare the properties. Correlation ID is the command UUID.
117                let properties = lapin::BasicProperties::default()
118                    .with_content_type("application/json".into())
119                    .with_headers(headers)
120                    .with_correlation_id(uuid.clone().into());
121
122                // Routing key is 'reply.<commander_id>' unless this is a broadcast,
123                // in which case it is 'reply.broadcast'.
124                let routing_key = if is_broadcast {
125                    "reply.broadcast".to_string()
126                } else {
127                    format!("reply.{}", commander_id)
128                };
129
130                // Serialize the keywords to JSON for the message payload.
131                let payload = serde_json::to_vec(&reply.keywords).unwrap();
132
133                // Publish the message to RabbitMQ.
134                if let Err(e) = channel
135                    .basic_publish(
136                        &config.exchange,
137                        routing_key.as_str(),
138                        BasicPublishOptions::default(),
139                        payload.as_slice(),
140                        properties,
141                    )
142                    .await
143                {
144                    log::error!("Failed to publish reply to RabbitMQ: {}", e);
145                }
146
147                // If the reply code indicates the command is finished, return the command ID to the pool.
148                if reply.code == ':' || reply.code.eq_ignore_ascii_case(&'f') {
149                    log::debug!(
150                        "Command {} is finished with code '{}'. Returning command_id to the pool.",
151                        reply.command_id,
152                        reply.code,
153                    );
154                    command_id_pool.finish_command(command_id as u16);
155                }
156            }
157            Err(_) => {
158                log::warn!("RabbitMQ receiver channel closed");
159                return Ok(());
160            }
161        }
162    }
163}
164
165/// Starts the RabbitMQ service to listen for commands and handle replies.
166///
167/// # Arguments
168/// * `config` - The RabbitMQ configuration.
169/// * `tcp_sender` - The sender channel to send TCP commands.
170/// * `rabbitmq_receiver` - The receiver channel to receive replies from TCP.
171///
172pub async fn start_rabbitmq_service(
173    config: RabbitMQConfig,
174    tcp_sender: Sender<BytesMut>,
175    rabbitmq_receiver: Receiver<Reply>,
176) -> Result<(), String> {
177    // Create a mutex-protected CommandID pool.
178    let command_id_pool_mutex = Arc::new(Mutex::new(CommandID::new()));
179
180    // Connect to RabbitMQ.
181    let connection_properties = ConnectionProperties::default();
182    let connection = match Connection::connect(&config.uri, connection_properties).await {
183        Ok(conn) => conn,
184        Err(e) => {
185            log::error!("Failed to connect to RabbitMQ: {}", e);
186            return Err(e.to_string());
187        }
188    };
189
190    log::debug!("Connected to RabbitMQ at {}", config.uri);
191
192    // Create a channel.
193    let channel = match connection.create_channel().await {
194        Ok(channel) => channel,
195        Err(e) => {
196            log::error!("Failed to create RabbitMQ channel: {}", e);
197            return Err(e.to_string());
198        }
199    };
200    log::debug!("Created RabbitMQ channel");
201
202    // Declare a queue for receiving commands. It needs to be exclusive and not auto-delete.
203    let queue_options = lapin::options::QueueDeclareOptions {
204        auto_delete: false,
205        exclusive: true,
206        ..Default::default()
207    };
208    let queue = match channel
209        .queue_declare(
210            format!("{}_commands", config.actor_name).as_str(),
211            queue_options,
212            FieldTable::default(),
213        )
214        .await
215    {
216        Ok(q) => q,
217        Err(e) => {
218            log::error!("Failed to declare RabbitMQ queue: {}", e);
219            return Err(e.to_string());
220        }
221    };
222    log::debug!("Declared RabbitMQ queue '{}'", queue.name().as_str());
223
224    // Declare the exchange. This needs to be auto-delete but not durable.
225    let exchange_declare_options = ExchangeDeclareOptions {
226        auto_delete: true,
227        ..Default::default()
228    };
229    match channel
230        .exchange_declare(
231            &config.exchange,
232            lapin::ExchangeKind::Topic,
233            exchange_declare_options,
234            FieldTable::default(),
235        )
236        .await
237    {
238        Ok(_) => (),
239        Err(e) => {
240            log::error!("Failed to declare RabbitMQ exchange: {}", e);
241            return Err(e.to_string());
242        }
243    };
244    log::debug!("Declared RabbitMQ exchange '{}'", config.exchange.as_str());
245
246    // Bind the queue to the exchange with the routing key 'command.<actor_name>.#'.
247    match channel
248        .queue_bind(
249            queue.name().as_str(),
250            &config.exchange,
251            format!("command.{}.#", config.actor_name).as_str(),
252            QueueBindOptions::default(),
253            FieldTable::default(),
254        )
255        .await
256    {
257        Ok(_) => (),
258        Err(e) => {
259            log::error!("Failed to bind RabbitMQ queue: {}", e);
260            return Err(e.to_string());
261        }
262    };
263    log::debug!(
264        "Bound RabbitMQ queue '{}' to exchange '{}' with routing key 'command.{}.#'",
265        queue.name().as_str(),
266        config.exchange.as_str(),
267        config.actor_name.as_str()
268    );
269
270    let mut consumer = match channel
271        .basic_consume(
272            queue.name().as_str(),
273            format!("{}_consumer", config.actor_name).as_str(),
274            BasicConsumeOptions::default(),
275            FieldTable::default(),
276        )
277        .await
278    {
279        Ok(consumer) => consumer,
280        Err(e) => {
281            log::error!("Failed to start RabbitMQ consumer: {}", e);
282            return Err(e.to_string());
283        }
284    };
285    log::debug!("Started RabbitMQ consumer");
286
287    // If monitoring TCP replies, spawn a task to handle them. This needs a separate channel
288    // since it will be moved into the task. But we want to create the channel here after declaring
289    // the exchange and queue. Not totally sure why this is necessary, but if one creates another
290    // channel before declaring the exchange the channel closes immediately. So either one does
291    // this or declares the exchange for the channel used for publishing replies.
292    if config.monitor_tcp_replies {
293        let channel_b = match connection.create_channel().await {
294            Ok(channel) => channel,
295            Err(e) => {
296                log::error!("Failed to create RabbitMQ channel: {}", e);
297                return Err(e.to_string());
298            }
299        };
300
301        // Clone the config for the task.
302        let config_clone = config.clone();
303
304        // Clone the command ID pool mutex for the task.
305        let command_id_pool_clone = command_id_pool_mutex.clone();
306
307        tokio::spawn(async move {
308            if let Err(e) = handle_replies(
309                channel_b,
310                config_clone,
311                rabbitmq_receiver,
312                command_id_pool_clone,
313            )
314            .await
315            {
316                // Non-fatal error, just log it, but maybe we should abort the whole service?
317                log::error!("Failed to handle RabbitMQ replies: {}", e);
318            }
319        });
320    }
321
322    // Process incoming commands.
323    while let Some(delivery) = consumer.next().await {
324        let delivery = delivery.unwrap();
325        match delivery.ack(BasicAckOptions::default()).await {
326            Ok(_) => (),
327            Err(e) => {
328                log::error!("Failed to acknowledge RabbitMQ message: {}", e);
329                continue;
330            }
331        }
332        process_command(&tcp_sender, &delivery, &command_id_pool_mutex).await;
333    }
334
335    Ok(())
336}
337
338fn get_header_value(delivery: &Delivery, key: &str) -> Option<AMQPValue> {
339    if let Some(headers) = delivery.properties.headers() {
340        headers.inner().get(key).cloned()
341    } else {
342        None
343    }
344}
345
346/// Processes a command received from RabbitMQ and sends it to the actor.
347///
348/// # Arguments
349/// * `tcp_sender` - The sender channel to send TCP commands.
350/// * `delivery` - The RabbitMQ delivery containing the command.
351/// * `command_id_pool` - A mutex-protected CommandID pool for tracking command UUIDs.
352///
353pub async fn process_command(
354    tcp_sender: &Sender<BytesMut>,
355    delivery: &Delivery,
356    command_id_pool: &Arc<Mutex<CommandID>>,
357) {
358    // Extract command_id and commander_id from headers.
359    let command_id = get_header_value(delivery, "command_id");
360    let command_id = match command_id {
361        Some(value) => match value {
362            AMQPValue::LongString(id) => id.to_string(),
363            _ => {
364                log::warn!("Command ID header is not a LongString");
365                return;
366            }
367        },
368        None => {
369            log::warn!("Command ID not found in message headers");
370            return;
371        }
372    };
373
374    let commander_id = get_header_value(delivery, "commander_id");
375    let commander_id = match commander_id {
376        Some(value) => match value {
377            AMQPValue::LongString(id) => id.to_string(),
378            _ => {
379                log::warn!("Commander ID header is not a LongString");
380                return;
381            }
382        },
383        None => {
384            log::warn!("Commander ID not found in message headers");
385            return;
386        }
387    };
388
389    // Deserialize the command payload.
390    let command_string: CommandPayload = serde_json::from_slice(&delivery.data).unwrap();
391
392    log::debug!(
393        "Processing command from commander {} with UUID '{}' and command string '{}'",
394        commander_id,
395        command_id,
396        command_string.command_string
397    );
398
399    // Get a free TCP command ID from the pool.
400    let mut command_id_pool = command_id_pool.lock().await;
401    let tcp_command_id = command_id_pool.get_command_id();
402    log::debug!(
403        "Mapped UUID '{}' to TCP command ID {}",
404        command_id,
405        tcp_command_id
406    );
407
408    // Prepare the TCP command string: "<tcp_command_id> <command_string>" and send it to the TCP queue.
409    let tcp_command_string = format!("{} {}", tcp_command_id, command_string.command_string);
410
411    log::debug!("Queuing TCP command: '{}'", tcp_command_string);
412
413    tcp_sender
414        .send(BytesMut::from(tcp_command_string.as_bytes()))
415        .await
416        .unwrap();
417
418    // Register the mapping of TCP command ID to UUID and commander ID.
419    command_id_pool.register_command(&command_id, &commander_id, tcp_command_id);
420}
421
422/// Represents the payload of a command message.
423#[derive(serde::Deserialize, Debug)]
424struct CommandPayload {
425    /// The actual command string to be executed by the actor.
426    command_string: String,
427}