Skip to content

Commit

Permalink
graceful mid_process timeout; fixes #94
Browse files Browse the repository at this point in the history
  • Loading branch information
mmoskal committed Apr 18, 2024
1 parent 5a5a28f commit 25470c1
Show file tree
Hide file tree
Showing 3 changed files with 115 additions and 46 deletions.
1 change: 1 addition & 0 deletions aicirt/src/hostimpl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ pub struct AiciLimits {
pub max_step_ms: u64,
pub max_init_ms: u64,
pub max_compile_ms: u64,
pub max_timeout_steps: usize,
pub logit_memory_bytes: usize,
pub busy_wait_duration: Duration,
pub max_forks: usize,
Expand Down
57 changes: 50 additions & 7 deletions aicirt/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@ use crate::{
worker::{RtMidProcessArg, WorkerForker},
TimerSet,
};
use aici_abi::{bytes::limit_str, toktree::TokTrie, MidProcessArg, SeqId};
use aici_abi::{
bytes::limit_str, toktree::TokTrie, Branch, MidProcessArg, ProcessResultOffset, SeqId,
};
use aicirt::{bintokens::find_tokenizer, futexshm::ServerChannel, *};
use anyhow::{anyhow, ensure, Result};
use base64::{self, Engine as _};
Expand Down Expand Up @@ -119,9 +121,13 @@ struct Cli {
wasm_max_memory: usize,

/// Maximum time WASM module can execute step in milliseconds
#[arg(long, default_value = "150")]
#[arg(long, default_value = "50")]
wasm_max_step_time: u64,

/// How many steps have to time out before the sequence is terminated
#[arg(long, default_value = "10")]
wasm_max_timeout_steps: usize,

/// Maximum time WASM module can execute initialization code in milliseconds
#[arg(long, default_value = "1000")]
wasm_max_init_time: u64,
Expand Down Expand Up @@ -162,6 +168,7 @@ struct ModuleRegistry {
struct Stepper {
req_instances: Arc<Mutex<HashMap<String, SeqWorkerHandle>>>,
instances: HashMap<ModuleInstId, SeqWorkerHandle>,
num_timeouts: HashMap<ModuleInstId, usize>,
limits: AiciLimits,
globals: GlobalInfo,
// for debugging
Expand Down Expand Up @@ -623,6 +630,7 @@ impl Stepper {
Ok(Self {
req_instances: reg.req_instances.clone(),
instances: HashMap::default(),
num_timeouts: HashMap::default(),
limits,
globals: reg.wasm_ctx.globals.clone(),
shm,
Expand Down Expand Up @@ -744,6 +752,8 @@ impl Stepper {
let mask_num_bytes = slice.len() * 4;
slice.iter_mut().for_each(|v| *v = 0.0);

let start_time = Instant::now();

for op in req.ops.into_iter() {
let instid = op.id;
if let Ok(h) = self.get_worker(instid) {
Expand All @@ -763,10 +773,18 @@ impl Stepper {
logit_offset,
logit_size,
};
match h.start_process(op) {
Ok(_) => used_ids.push((logit_offset, instid)),
Err(e) => self.worker_error(instid, &mut outputs, e),
};
if self.num_timeouts.get(&instid).is_some() {
assert!(op.op.backtrack == 0);
assert!(op.op.tokens.is_empty());
// TODO logit_offset!
log::debug!("{instid} still pending (timeout in previous round)");
used_ids.push((logit_offset, instid));
} else {
match h.start_process(op) {
Ok(_) => used_ids.push((logit_offset, instid)),
Err(e) => self.worker_error(instid, &mut outputs, e),
}
}
} else {
log::info!("invalid id {}", instid);
}
Expand All @@ -776,6 +794,7 @@ impl Stepper {
let deadline = Instant::now() + std::time::Duration::from_millis(self.limits.max_step_ms);

for (off, id) in used_ids {
let prev_timeout = self.num_timeouts.remove(&id).unwrap_or(0);
let h = self.get_worker(id).unwrap();
let timeout = deadline.saturating_duration_since(Instant::now());
match h.check_process(timeout) {
Expand Down Expand Up @@ -813,7 +832,30 @@ impl Stepper {
log::trace!("logits: {} allow; tokens: {}", allow_set.len(), list);
}
}
Err(e) => self.worker_error(id, &mut outputs, e),
Err(e) => {
if e.to_string() == "timeout" && prev_timeout < self.limits.max_timeout_steps {
outputs.insert(
id,
SequenceResult {
result: Some(ProcessResultOffset {
branches: vec![Branch::noop()],
}),
error: String::new(),
storage: vec![],
logs: format!(
"⏲ timeout [deadline: {}ms; step {}/{}]\n",
self.limits.max_step_ms,
prev_timeout + 1,
self.limits.max_timeout_steps
),
micros: start_time.elapsed().as_micros() as u64,
},
);
self.num_timeouts.insert(id, prev_timeout + 1);
} else {
self.worker_error(id, &mut outputs, e)
}
}
}
}

Expand Down Expand Up @@ -1142,6 +1184,7 @@ fn main() -> () {
max_memory_bytes: cli.wasm_max_memory * MEGABYTE,
max_init_ms: cli.wasm_max_init_time,
max_step_ms: cli.wasm_max_step_time,
max_timeout_steps: cli.wasm_max_timeout_steps,
max_compile_ms: 10_000,
logit_memory_bytes: cli.bin_size * MEGABYTE,
busy_wait_duration: Duration::from_millis(cli.busy_wait_time),
Expand Down
103 changes: 64 additions & 39 deletions aicirt/src/worker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@ use aicirt::{
futexshm::{TypedClient, TypedClientHandle, TypedServer},
set_max_priority,
shm::Unlink,
user_error, variables::Variables,
user_error,
variables::Variables,
};
use anyhow::{anyhow, Result};
use libc::pid_t;
Expand All @@ -25,7 +26,8 @@ use std::{
time::{Duration, Instant},
};

const QUICK_OP_MS: u64 = 10;
const QUICK_OP_MS: u64 = 3;
const QUICK_OP_RETRY_MS: u64 = 100;

#[derive(Serialize, Deserialize, Debug)]
pub enum GroupCmd {
Expand Down Expand Up @@ -174,6 +176,18 @@ enum SeqResp {
Error { msg: String, is_user_error: bool },
}

/// Deadline policy for waiting on a response from a worker process.
pub enum Timeout {
/// Short fixed deadline (QUICK_OP_MS) for cheap control operations;
/// on expiry the receiver retries once for QUICK_OP_RETRY_MS and logs
/// a "slow quick op" warning if the retry succeeds.
Quick,
/// Hard deadline: expiry is logged as an error and returned as a
/// "timeout" failure.
Strict(Duration),
/// Deadline for operations that may legitimately still be running
/// (e.g. a mid-process step); expiry is returned as a "timeout"
/// failure but NOT logged as an error.
Speculative(Duration),
}

impl Timeout {
/// Convenience constructor: a strict deadline of `millis` milliseconds.
pub fn from_millis(millis: u64) -> Self {
Timeout::Strict(Duration::from_millis(millis))
}
}

impl<Cmd, Resp> ProcessHandle<Cmd, Resp>
where
Cmd: for<'d> Deserialize<'d> + Serialize + Debug,
Expand Down Expand Up @@ -208,29 +222,42 @@ where
unsafe { libc::kill(self.pid, libc::SIGKILL) }
}

fn recv_with_timeout(&self, lbl: &str, timeout: Duration) -> Result<Resp> {
fn recv_with_timeout(&self, lbl: &str, timeout: Timeout) -> Result<Resp> {
let t0 = Instant::now();
match self.recv_with_timeout_inner(timeout) {
let d = match timeout {
Timeout::Quick => Duration::from_millis(QUICK_OP_MS),
Timeout::Strict(d) => d,
Timeout::Speculative(d) => d,
};
let r = match timeout {
Timeout::Quick => match self.recv_with_timeout_inner(d) {
None => {
match self.recv_with_timeout_inner(Duration::from_millis(QUICK_OP_RETRY_MS)) {
Some(r) => {
let dur = t0.elapsed();
log::warn!("{lbl}: slow quick op: {dur:?}");
Some(r)
}
None => None,
}
}
r => r,
},
_ => self.recv_with_timeout_inner(d),
};

match r {
Some(r) => Ok(r),
None => {
let second_try = Duration::from_millis(200);
let r = if timeout < second_try {
let r = self.recv_with_timeout_inner(second_try);
let dur = t0.elapsed();
log::warn!("{lbl}: timeout {dur:?} (allowed: {timeout:?})");
r
} else {
None
};
match r {
Some(r) => Ok(r),
None => {
None => match r {
Some(r) => Ok(r),
None => {
if !matches!(timeout, Timeout::Speculative(_)) {
let dur = t0.elapsed();
log::error!("{lbl}: timeout {dur:?} (allowed: {timeout:?})");
Err(anyhow!("timeout"))
log::error!("{lbl}: timeout {dur:?} (allowed: {d:?})");
}
Err(anyhow!("timeout"))
}
}
},
}
}

Expand All @@ -246,7 +273,7 @@ where
}

impl SeqHandle {
fn send_cmd_expect_ok(&self, cmd: SeqCmd, timeout: Duration) -> Result<()> {
fn send_cmd_expect_ok(&self, cmd: SeqCmd, timeout: Timeout) -> Result<()> {
let tag = cmd.tag();
self.just_send(cmd)?;
match self.seq_recv_with_timeout(tag, timeout) {
Expand All @@ -256,7 +283,7 @@ impl SeqHandle {
}
}

fn seq_recv_with_timeout(&self, lbl: &str, timeout: Duration) -> Result<SeqResp> {
fn seq_recv_with_timeout(&self, lbl: &str, timeout: Timeout) -> Result<SeqResp> {
match self.recv_with_timeout(lbl, timeout) {
Ok(SeqResp::Error { msg, is_user_error }) => {
if is_user_error {
Expand All @@ -269,7 +296,7 @@ impl SeqHandle {
}
}

fn send_cmd_with_timeout(&self, cmd: SeqCmd, timeout: Duration) -> Result<SeqResp> {
fn send_cmd_with_timeout(&self, cmd: SeqCmd, timeout: Timeout) -> Result<SeqResp> {
let tag = cmd.tag();
self.just_send(cmd)?;
self.seq_recv_with_timeout(tag, timeout)
Expand Down Expand Up @@ -431,31 +458,26 @@ impl Drop for SeqWorkerHandle {

impl SeqWorkerHandle {
pub fn set_id(&self, id: ModuleInstId) -> Result<()> {
self.handle.send_cmd_expect_ok(
SeqCmd::SetId { inst_id: id },
Duration::from_millis(QUICK_OP_MS),
)
self.handle
.send_cmd_expect_ok(SeqCmd::SetId { inst_id: id }, Timeout::Quick)
}

pub fn run_main(&self) -> Result<()> {
self.handle
.send_cmd_expect_ok(SeqCmd::RunMain {}, Duration::from_secs(120))
.send_cmd_expect_ok(SeqCmd::RunMain {}, Timeout::from_millis(120_000))
}

pub fn fork(&self, target_id: ModuleInstId) -> Result<SeqWorkerHandle> {
match self.handle.send_cmd_with_timeout(
SeqCmd::Fork { inst_id: target_id },
Duration::from_millis(QUICK_OP_MS),
)? {
match self
.handle
.send_cmd_with_timeout(SeqCmd::Fork { inst_id: target_id }, Timeout::Quick)?
{
SeqResp::Fork { handle } => {
let res = SeqWorkerHandle {
req_id: self.req_id.clone(),
handle: handle.to_client(),
};
match res
.handle
.recv_with_timeout("r-fork", Duration::from_millis(QUICK_OP_MS))?
{
match res.handle.recv_with_timeout("r-fork", Timeout::Quick)? {
SeqResp::Ok {} => Ok(res),
r => Err(anyhow!("unexpected response (fork, child) {r:?}")),
}
Expand All @@ -470,7 +492,10 @@ impl SeqWorkerHandle {
}

pub fn check_process(&self, timeout: Duration) -> Result<SequenceResult<ProcessResultOffset>> {
match self.handle.seq_recv_with_timeout("r-process", timeout) {
match self
.handle
.seq_recv_with_timeout("r-process", Timeout::Speculative(timeout))
{
Ok(SeqResp::MidProcess { json }) => Ok(serde_json::from_str(&json)?),
Ok(r) => Err(anyhow!("unexpected response (process) {r:?}")),
Err(e) => Err(e.into()),
Expand Down Expand Up @@ -682,7 +707,7 @@ impl WorkerForker {
prompt_str,
prompt_toks,
},
Duration::from_millis(self.limits.max_init_ms),
Timeout::from_millis(self.limits.max_init_ms),
)? {
SeqResp::InitPrompt { json } => {
let r: SequenceResult<()> = serde_json::from_str(&json)?;
Expand All @@ -706,7 +731,7 @@ impl WorkerForker {
};
match res.handle.send_cmd_with_timeout(
SeqCmd::Compile { wasm },
Duration::from_millis(self.limits.max_compile_ms),
Timeout::from_millis(self.limits.max_compile_ms),
)? {
SeqResp::Compile { binary } => Ok(binary),
r => Err(anyhow!("unexpected response (compile) {r:?}")),
Expand Down

0 comments on commit 25470c1

Please sign in to comment.