diff --git a/examples/Cargo.toml b/examples/Cargo.toml index 251259e9ee..dbd15dbdab 100644 --- a/examples/Cargo.toml +++ b/examples/Cargo.toml @@ -25,6 +25,10 @@ path = "ws.rs" name = "ws_subscription" path = "ws_subscription.rs" +[[example]] +name = "ws_sub_with_params" +path = "ws_sub_with_params.rs" + [[example]] name = "proc_macro" path = "proc_macro.rs" diff --git a/examples/ws_sub_with_params.rs b/examples/ws_sub_with_params.rs new file mode 100644 index 0000000000..a8669467b8 --- /dev/null +++ b/examples/ws_sub_with_params.rs @@ -0,0 +1,73 @@ +// Copyright 2019 Parity Technologies (UK) Ltd. +// +// Permission is hereby granted, free of charge, to any +// person obtaining a copy of this software and associated +// documentation files (the "Software"), to deal in the +// Software without restriction, including without +// limitation the rights to use, copy, modify, merge, +// publish, distribute, sublicense, and/or sell copies of +// the Software, and to permit persons to whom the Software +// is furnished to do so, subject to the following +// conditions: +// +// The above copyright notice and this permission notice +// shall be included in all copies or substantial portions +// of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF +// ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED +// TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A +// PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT +// SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY +// CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +// OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR +// IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +// DEALINGS IN THE SOFTWARE. + +use jsonrpsee::{ + ws_client::{traits::SubscriptionClient, v2::params::JsonRpcParams, WsClientBuilder}, + ws_server::WsServer, +}; +use std::net::SocketAddr; + +#[tokio::main] +async fn main() -> anyhow::Result<()> { + env_logger::init(); + let addr = run_server().await?; + let url = format!("ws://{}", addr); + + let client = WsClientBuilder::default().build(&url).await?; + + // Subscription with a single parameter + let params = JsonRpcParams::Array(vec![3.into()]); + let mut sub_params_one = client.subscribe::>("sub_one_param", params, "unsub_one_param").await?; + println!("subscription with one param: {:?}", sub_params_one.next().await); + + // Subscription with multiple parameters + let params = JsonRpcParams::Array(vec![2.into(), 5.into()]); + let mut sub_params_two = client.subscribe::("sub_params_two", params, "unsub_params_two").await?; + println!("subscription with two params: {:?}", sub_params_two.next().await); + + Ok(()) +} + +async fn run_server() -> anyhow::Result { + const LETTERS: &'static str = "abcdefghijklmnopqrstuvxyz"; + let mut server = WsServer::new("127.0.0.1:0").await?; + let one_param = server.register_subscription("sub_one_param", "unsub_one_param").unwrap(); + let two_params = server.register_subscription("sub_params_two", "unsub_params_two").unwrap(); + + std::thread::spawn(move || loop { + let _ = one_param.send_each(|idx| Ok(LETTERS.chars().nth(*idx))); + std::thread::sleep(std::time::Duration::from_millis(50)); + }); + + std::thread::spawn(move || loop { + let _ = two_params.send_each(|params: &Vec| Ok(Some(LETTERS[params[0]..params[1]].to_string()))); + std::thread::sleep(std::time::Duration::from_millis(100)); + }); + + let addr = server.local_addr()?; + tokio::spawn(async move { server.start().await }); + Ok(addr) +} diff --git a/examples/ws_subscription.rs b/examples/ws_subscription.rs index 2571138c52..1f26953aee 100644 --- a/examples/ws_subscription.rs +++ b/examples/ws_subscription.rs @@ -54,10 +54,10 @@ async fn main() -> anyhow::Result<()> { async fn run_server() -> anyhow::Result { let mut server = WsServer::new("127.0.0.1:0").await?; - let mut subscription = server.register_subscription("subscribe_hello", "unsubscribe_hello").unwrap(); + let subscription = server.register_subscription::<()>("subscribe_hello", "unsubscribe_hello").unwrap(); std::thread::spawn(move || loop { - subscription.send(&"hello my friend").unwrap(); + subscription.broadcast(&"hello my friend").unwrap(); std::thread::sleep(std::time::Duration::from_secs(1)); }); diff --git a/tests/tests/helpers.rs b/tests/tests/helpers.rs index 0adbaa4cd5..fc590550ec 100644 --- a/tests/tests/helpers.rs +++ b/tests/tests/helpers.rs @@ -25,7 +25,10 @@ // DEALINGS IN THE SOFTWARE. use futures_channel::oneshot; -use jsonrpsee::{http_server::HttpServerBuilder, ws_server::WsServer}; +use jsonrpsee::{ + http_server::HttpServerBuilder, + ws_server::{SubscriptionSink, WsServer}, +}; use std::net::SocketAddr; use std::time::Duration; @@ -36,9 +39,11 @@ pub async fn websocket_server_with_subscription() -> SocketAddr { let rt = tokio::runtime::Runtime::new().unwrap(); let mut server = rt.block_on(WsServer::new("127.0.0.1:0")).unwrap(); - let mut sub_hello = server.register_subscription("subscribe_hello", "unsubscribe_hello").unwrap(); - let mut sub_foo = server.register_subscription("subscribe_foo", "unsubscribe_foo").unwrap(); - + let sub_hello: SubscriptionSink<()> = + server.register_subscription("subscribe_hello", "unsubscribe_hello").unwrap(); + let sub_foo: SubscriptionSink<()> = server.register_subscription("subscribe_foo", "unsubscribe_foo").unwrap(); + let sub_add_one: SubscriptionSink = + server.register_subscription("subscribe_add_one", "unsubscribe_add_one").unwrap(); server.register_method("say_hello", |_| Ok("hello")).unwrap(); server_started_tx.send(server.local_addr().unwrap()).unwrap(); @@ -49,8 +54,9 @@ pub async fn websocket_server_with_subscription() -> SocketAddr { loop { tokio::time::sleep(Duration::from_millis(100)).await; - sub_hello.send(&"hello from subscription").unwrap(); - sub_foo.send(&1337_u64).unwrap(); + sub_hello.broadcast(&"hello from subscription").unwrap(); + sub_foo.broadcast(&1337_u64).unwrap(); + sub_add_one.send_each(|p| Ok(Some(*p + 1))).unwrap(); } }); }); diff --git a/tests/tests/integration_tests.rs b/tests/tests/integration_tests.rs index e809ef8318..95f1f18b18 100644 --- a/tests/tests/integration_tests.rs +++ b/tests/tests/integration_tests.rs @@ -54,6 +54,20 @@ async fn ws_subscription_works() { } } +#[tokio::test] +async fn ws_subscription_with_input_works() { + let server_addr = websocket_server_with_subscription().await; + let server_url = format!("ws://{}", server_addr); + let client = WsClientBuilder::default().build(&server_url).await.unwrap(); + let mut add_one: Subscription = + client.subscribe("subscribe_add_one", vec![1.into()].into(), "unsubscribe_add_one").await.unwrap(); + + for _ in 0..2 { + let two = add_one.next().await.unwrap(); + assert_eq!(two, 2); + } +} + #[tokio::test] async fn ws_method_call_works() { let server_addr = websocket_server().await; diff --git a/types/src/error.rs b/types/src/error.rs index cce5d5b5aa..61cb4066b9 100644 --- a/types/src/error.rs +++ b/types/src/error.rs @@ -40,7 +40,7 @@ pub enum Error { Request(String), /// Frontend/backend channel error. #[error("Frontend/backend channel error: {0}")] - Internal(#[source] futures_channel::mpsc::SendError), + Internal(#[from] futures_channel::mpsc::SendError), /// Invalid response, #[error("Invalid response: {0}")] InvalidResponse(Mismatch), diff --git a/types/src/v2/params.rs b/types/src/v2/params.rs index 90d3c137de..5caa12ba10 100644 --- a/types/src/v2/params.rs +++ b/types/src/v2/params.rs @@ -83,10 +83,8 @@ impl<'a> RpcParams<'a> { where T: Deserialize<'a>, { - match self.0 { - None => Err(CallError::InvalidParams), - Some(params) => serde_json::from_str(params).map_err(|_| CallError::InvalidParams), - } + let params = self.0.unwrap_or("null"); + serde_json::from_str(params).map_err(|_| CallError::InvalidParams) } /// Attempt to parse only the first parameter from an array into type T diff --git a/utils/src/server/rpc_module.rs b/utils/src/server/rpc_module.rs index 00eae19ef4..abcb2f5606 100644 --- a/utils/src/server/rpc_module.rs +++ b/utils/src/server/rpc_module.rs @@ -8,8 +8,8 @@ use jsonrpsee_types::v2::response::JsonRpcSubscriptionResponse; use parking_lot::Mutex; use rustc_hash::FxHashMap; -use serde::Serialize; -use serde_json::value::to_raw_value; +use serde::{de::DeserializeOwned, Serialize}; +use serde_json::value::{to_raw_value, RawValue}; use std::ops::{Deref, DerefMut}; use std::sync::Arc; @@ -28,9 +28,11 @@ pub type SubscriptionId = u64; /// Sink that is used to send back the result to the server for a specific method. pub type MethodSink = mpsc::UnboundedSender; -type Subscribers = Arc>>; +/// Map of subscribers keyed by the connection and subscription ids to an [`InnerSink`] that contains the parameters +/// they used to subscribe and the tx side of a channel used to convey results and errors back. +type Subscribers

= Arc>>>; -/// Sets of JSON-RPC methods can be organized into a "module" that are in turn registered on server or, +/// Sets of JSON-RPC methods can be organized into a "module"s that are in turn registered on the server or, /// alternatively, merged with other modules to construct a cohesive API. #[derive(Default)] pub struct RpcModule { @@ -88,12 +90,14 @@ impl RpcModule { Ok(()) } - /// Register a new RPC subscription, with subscribe and unsubscribe methods. - pub fn register_subscription( + /// Register a new RPC subscription, with subscribe and unsubscribe methods. Returns a [`SubscriptionSink`]. If a + /// method with the same name is already registered, an [`Error::MethodAlreadyRegistered`] is returned. + /// If the subscription does not take any parameters, set `Params` to `()`. + pub fn register_subscription( &mut self, subscribe_method_name: &'static str, unsubscribe_method_name: &'static str, - ) -> Result { + ) -> Result, Error> { if subscribe_method_name == unsubscribe_method_name { return Err(Error::SubscriptionNameConflict(subscribe_method_name.into())); } @@ -107,13 +111,21 @@ impl RpcModule { let subscribers = subscribers.clone(); self.methods.insert( subscribe_method_name, - Box::new(move |id, _, tx, conn| { + Box::new(move |id, params, tx, conn| { + let params = match params.parse().or_else(|_| params.one()) { + Ok(p) => p, + Err(err) => { + log::error!("Params={:?}, in subscription couldn't be parsed: {:?}", params, err); + return Err(err.into()); + } + }; let sub_id = { const JS_NUM_MASK: SubscriptionId = !0 >> 11; let sub_id = rand::random::() & JS_NUM_MASK; - subscribers.lock().insert((conn, sub_id), tx.clone()); + let inner = InnerSink { sink: tx.clone(), params, method: subscribe_method_name, sub_id }; + subscribers.lock().insert((conn, sub_id), inner); sub_id }; @@ -236,16 +248,18 @@ impl DerefMut for RpcContextModule { } /// Used by the server to send data back to subscribers. #[derive(Clone)] -pub struct SubscriptionSink { +pub struct SubscriptionSink { method: &'static str, - subscribers: Subscribers, + subscribers: Subscribers, } -impl SubscriptionSink { - /// Send data back to subscribers. - /// If a send fails (likely a broken connection) the subscriber is removed from the sink. - /// O(n) in the number of subscribers. - pub fn send(&mut self, result: &T) -> Result<(), Error> +impl SubscriptionSink { + /// Send a message to all subscribers. + /// + /// If you have subscriptions with params/input you should most likely + /// call `send_each` to the process the input/params and send out + /// the result on each subscription individually instead. + pub fn broadcast(&self, result: &T) -> Result<(), Error> where T: Serialize, { @@ -254,25 +268,114 @@ impl SubscriptionSink { let mut errored = Vec::new(); let mut subs = self.subscribers.lock(); - for ((conn_id, sub_id), sender) in subs.iter() { - let msg = serde_json::to_string(&JsonRpcSubscriptionResponse { - jsonrpc: TwoPointZero, - method: self.method, - params: JsonRpcNotificationParams { subscription: *sub_id, result: &*result }, - })?; - - // Track broken connections - if sender.unbounded_send(msg).is_err() { + for ((conn_id, sub_id), sink) in subs.iter() { + // Mark broken connections, to be removed. + if sink.send_raw_value(&result).is_err() { errored.push((*conn_id, *sub_id)); } } // Remove broken connections for entry in errored { + log::debug!("Dropping subscription on method: {}, id: {}", self.method, entry.1); + subs.remove(&entry); + } + + Ok(()) + } + + /// Send a message to all subscribers one by one, parsing the params they sent with the provided closure. If the + /// closure `F` fails to parse the params the message is not sent. + /// + /// F: is a closure that you need to provide to apply on the input P. + pub fn send_each(&self, f: F) -> Result<(), Error> + where + F: Fn(&Params) -> Result, Error>, + T: Serialize, + { + let mut subs = self.subscribers.lock(); + let mut errored = Vec::new(); + + for ((conn_id, sub_id), sink) in subs.iter() { + match f(&sink.params) { + Ok(Some(res)) => { + let result = match to_raw_value(&res) { + Ok(res) => res, + Err(err) => { + log::error!("Subscription: {} failed to serialize message: {:?}; ignoring", sub_id, err); + continue; + } + }; + + if sink.send_raw_value(&result).is_err() { + errored.push((*conn_id, *sub_id)); + } + } + Ok(None) => (), + Err(e) => { + if sink.inner_send(format!("Error: {:?}", e)).is_err() { + errored.push((*conn_id, *sub_id)); + } + } + } + } + + // Remove broken connections + for entry in errored { + log::debug!("Dropping subscription on method: {}, id: {}", self.method, entry.1); subs.remove(&entry); } + Ok(()) } + + /// Consumes the current subscriptions at the given time to get access to the individual subscribers. + /// The [`SubscriptionSink`] will accept new subscriptions after this is called. + // TODO(niklasad1): get rid of this if possible. + pub fn to_sinks(&self) -> impl IntoIterator> { + let mut subs = self.subscribers.lock(); + let sinks = std::mem::take(&mut *subs); + sinks.into_iter().map(|(_, v)| v) + } +} + +/// Represents a single subscription. +pub struct InnerSink { + /// Sink. + sink: mpsc::UnboundedSender, + /// Params. + params: Params, + /// Method. + method: &'static str, + /// Subscription ID. + sub_id: SubscriptionId, +} + +impl InnerSink { + /// Send message on this subscription. + pub fn send(&self, result: &T) -> Result<(), Error> { + let result = to_raw_value(result)?; + self.send_raw_value(&result) + } + + fn send_raw_value(&self, result: &RawValue) -> Result<(), Error> { + let msg = serde_json::to_string(&JsonRpcSubscriptionResponse { + jsonrpc: TwoPointZero, + method: self.method, + params: JsonRpcNotificationParams { subscription: self.sub_id, result: &*result }, + })?; + + self.inner_send(msg).map_err(Into::into) + } + + fn inner_send(&self, msg: String) -> Result<(), Error> { + self.sink.unbounded_send(msg).map_err(|e| Error::Internal(e.into_send_error())) + } + + /// Get params of the subscription. + pub fn params(&self) -> &Params { + &self.params + } } #[cfg(test)] @@ -298,7 +401,7 @@ mod tests { fn rpc_context_modules_can_register_subscriptions() { let cx = (); let mut cxmodule = RpcContextModule::new(cx); - let _subscription = cxmodule.register_subscription("hi", "goodbye"); + let _subscription = cxmodule.register_subscription::<()>("hi", "goodbye"); let methods = cxmodule.into_methods().keys().cloned().collect::>(); assert!(methods.contains(&"hi")); diff --git a/ws-server/src/server.rs b/ws-server/src/server.rs index 5943a18a03..347b7b5fea 100644 --- a/ws-server/src/server.rs +++ b/ws-server/src/server.rs @@ -27,10 +27,9 @@ use futures_channel::mpsc; use futures_util::io::{BufReader, BufWriter}; use futures_util::stream::StreamExt; -use serde::Serialize; +use serde::{de::DeserializeOwned, Serialize}; use soketto::handshake::{server::Response, Server as SokettoServer}; -use std::net::SocketAddr; -use std::sync::Arc; +use std::{net::SocketAddr, sync::Arc}; use tokio::net::{TcpListener, ToSocketAddrs}; use tokio_stream::wrappers::TcpListenerStream; use tokio_util::compat::TokioAsyncReadCompatExt; @@ -65,14 +64,13 @@ impl Server { } /// Register a new RPC subscription, with subscribe and unsubscribe methods. - pub fn register_subscription( + pub fn register_subscription( &mut self, subscribe_method_name: &'static str, unsubscribe_method_name: &'static str, - ) -> Result { + ) -> Result, Error> { self.root.register_subscription(subscribe_method_name, unsubscribe_method_name) } - /// Register all methods from a module on this server. pub fn register_module(&mut self, module: RpcModule) -> Result<(), Error> { self.root.merge(module) diff --git a/ws-server/src/tests.rs b/ws-server/src/tests.rs index 1de7642f12..432763829b 100644 --- a/ws-server/src/tests.rs +++ b/ws-server/src/tests.rs @@ -230,8 +230,8 @@ async fn register_methods_works() { let mut server = WsServer::new("127.0.0.1:0").with_default_timeout().await.unwrap().unwrap(); assert!(server.register_method("say_hello", |_| Ok("lo")).is_ok()); assert!(server.register_method("say_hello", |_| Ok("lo")).is_err()); - assert!(server.register_subscription("subscribe_hello", "unsubscribe_hello").is_ok()); - assert!(server.register_subscription("subscribe_hello_again", "unsubscribe_hello").is_err()); + assert!(server.register_subscription::<()>("subscribe_hello", "unsubscribe_hello").is_ok()); + assert!(server.register_subscription::<()>("subscribe_hello_again", "unsubscribe_hello").is_err()); assert!( server.register_method("subscribe_hello_again", |_| Ok("lo")).is_ok(), "Failed register_subscription should not have side-effects" @@ -242,7 +242,7 @@ async fn register_methods_works() { async fn register_same_subscribe_unsubscribe_is_err() { let mut server = WsServer::new("127.0.0.1:0").with_default_timeout().await.unwrap().unwrap(); assert!(matches!( - server.register_subscription("subscribe_hello", "subscribe_hello"), + server.register_subscription::<()>("subscribe_hello", "subscribe_hello"), Err(Error::SubscriptionNameConflict(_)) )); }