From d9288c2b54fe513ab565fe576e97856f6581f0b0 Mon Sep 17 00:00:00 2001 From: slinkydeveloper Date: Wed, 21 Aug 2024 17:26:22 +0200 Subject: [PATCH] Fix #9 --- src/endpoint/context.rs | 77 +++++++++++++++++++++++------------------ src/endpoint/futures.rs | 2 +- 2 files changed, 45 insertions(+), 34 deletions(-) diff --git a/src/endpoint/context.rs b/src/endpoint/context.rs index 970391e..a92fcd5 100644 --- a/src/endpoint/context.rs +++ b/src/endpoint/context.rs @@ -127,45 +127,56 @@ impl ContextInternal { pub fn input(&self) -> impl Future { let mut inner_lock = must_lock!(self.inner); - let input_result = - inner_lock - .vm - .sys_input() - .map_err(ErrorInner::VM) - .and_then(|raw_input| { - let headers = http::HeaderMap::::try_from( - &raw_input - .headers - .into_iter() - .map(|h| (h.key.to_string(), h.value.to_string())) - .collect::>(), - ) - .map_err(|e| ErrorInner::Deserialization { - syscall: "input_headers", - err: e.into(), - })?; + let input_result = inner_lock + .vm + .sys_input() + .map_err(ErrorInner::VM) + .map(|raw_input| { + let headers = http::HeaderMap::::try_from( + &raw_input + .headers + .into_iter() + .map(|h| (h.key.to_string(), h.value.to_string())) + .collect::>(), + ) + .map_err(|e| { + TerminalError::new_with_code(400, format!("Cannot decode headers: {e:?}")) + })?; - Ok(( - T::deserialize(&mut (raw_input.input.into())).map_err(|e| { - ErrorInner::Deserialization { - syscall: "input", - err: e.into(), - } - })?, - InputMetadata { - invocation_id: raw_input.invocation_id, - random_seed: raw_input.random_seed, - key: raw_input.key, - headers, - }, - )) - }); + Ok::<_, TerminalError>(( + T::deserialize(&mut (raw_input.input.into())).map_err(|e| { + TerminalError::new_with_code( + 400, + format!("Cannot decode input payload: {e:?}"), + ) + })?, + InputMetadata { + invocation_id: raw_input.invocation_id, + random_seed: raw_input.random_seed, + key: raw_input.key, + headers, + }, + )) + }); match input_result { - Ok(i) => { + Ok(Ok(i)) => { drop(inner_lock); return Either::Left(ready(i)); } + Ok(Err(err)) => { + let error_inner = ErrorInner::Deserialization { + syscall: "input", + err: err.0.clone().into(), + }; + let _ = inner_lock + .vm + .sys_write_output(NonEmptyValue::Failure(err.into())); + let _ = inner_lock.vm.sys_end(); + // This causes the trap, plus logs the error + inner_lock.handler_state.mark_error(error_inner.into()); + drop(inner_lock); + } Err(e) => { inner_lock.fail(e.into()); drop(inner_lock); diff --git a/src/endpoint/futures.rs b/src/endpoint/futures.rs index e37656e..4836feb 100644 --- a/src/endpoint/futures.rs +++ b/src/endpoint/futures.rs @@ -23,7 +23,7 @@ unsafe impl Sync for TrapFuture {} impl Future for TrapFuture { type Output = T; - fn poll(self: Pin<&mut Self>, ctx: &mut std::task::Context<'_>) -> Poll { + fn poll(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll { ctx.waker().wake_by_ref(); Poll::Pending }