feat: add basic support for caching
tspooner committed Dec 30, 2022
1 parent f6ccbdc commit 9ad7fcc
Showing 13 changed files with 236 additions and 78 deletions.
80 changes: 80 additions & 0 deletions aegir/examples/caching.rs
@@ -0,0 +1,80 @@
#[macro_use]
extern crate aegir;
extern crate rand;

use aegir::{
    ids::{X, C},
    ops,
    Context,
    Differentiable,
    Function,
    Identifier,
    Read,
    Write,
    Node,
};
use std::time;

pub struct Ctx {
    pub x: Vec<f64>,
    cache: Option<f64>
}

impl Ctx {
    pub fn new(x: Vec<f64>) -> Ctx {
        Ctx { x, cache: None, }
    }
}

impl AsMut<Self> for Ctx {
    fn as_mut(&mut self) -> &mut Self { self }
}

impl Context for Ctx {}

impl Read<X> for Ctx {
    type Buffer = Vec<f64>;

    fn read(&self, _: X) -> Option<Vec<f64>> { Some(self.x.clone()) }
}

impl Read<C> for Ctx {
    type Buffer = f64;

    fn read(&self, _: C) -> Option<f64> { self.cache.clone() }
}

impl Write<C> for Ctx {
    fn write(&mut self, _: C, value: f64) { self.cache.replace(value); }
}

macro_rules! time_op {
    ($node:ident) => {{
        let repeated = $node
            .add($node.clone())
            .add($node.clone())
            .add($node.clone())
            .add($node.clone())
            .add($node.clone())
            .add($node.clone())
            .add($node.clone());

        let mut ctx = Ctx::new(vec![1.0; 100_000_000]);

        let start = time::Instant::now();

        repeated.evaluate(&mut ctx).ok();

        time::Instant::now().duration_since(start)
    }}
}

fn main() {
    let x = X.into_var();

    let basic = ops::Sum(ops::Negate(ops::Double(x)));
    let cached = basic.clone().cached(C);

    println!("Basic:\t{:?}", time_op!(basic));
    println!("Cached:\t{:?}", time_op!(cached));
}
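
The benchmark above is deliberately read-heavy: every `Read<X>` call clones a 100-million-element buffer, so the uncached expression pays that cost for each repeated copy of the node, while the cached expression computes `Sum(Negate(Double(x)))` once, writes the scalar result under `C`, and serves the remaining copies from the cache. The `impl AsMut<Self> for Ctx` is what lets both an owned `Ctx` and `&mut ctx` satisfy the `CR: AsMut<Ctx>` bound on `evaluate`, via the standard library's blanket `AsMut` implementation for mutable references. Assuming the usual Cargo example layout, the benchmark should run from the aegir crate directory with the standard invocation (shown for illustration, not part of this diff):

    cargo run --release --example caching
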
2 changes: 1 addition & 1 deletion aegir/examples/linear.rs
@@ -41,7 +41,7 @@ fn main() {
ctx.y = ctx.x.iter().zip(true_weights.iter()).map(|(x, tw)| x * tw).sum();

// Evaluate gradient:
let g: [f64; N] = adj.evaluate(&ctx).unwrap();
let g: [f64; N] = adj.evaluate(&mut ctx).unwrap();

for i in 0..N {
ctx.w[i] -= 0.01 * g[i];
2 changes: 1 addition & 1 deletion aegir/examples/logistic.rs
@@ -44,7 +44,7 @@ fn main() {
ctx.y = rng.gen_bool(ops::sigmoid(yy)) as u8 as f64;

// Evaluate gradient:
let g: [f64; N] = adj.evaluate(&ctx).unwrap();
let g: [f64; N] = adj.evaluate(&mut ctx).unwrap();

for i in 0..N {
ctx.w[i] += 0.0005 * g[i];
36 changes: 28 additions & 8 deletions aegir/src/lib.rs
@@ -81,7 +81,9 @@ pub mod ids {
}

/// Trait for types that store data [buffers](buffers::Buffer).
pub trait Context: AsRef<Self> {}
pub trait Context: AsMut<Self> {

}

/// Trait for reading entries out of a [Context].
pub trait Read<I: Identifier>: Context {
@@ -104,6 +106,14 @@ pub trait Read<I: Identifier>: Context {
}
}

pub trait Write<I: Identifier>: Read<I> {
fn write(&mut self, ident: I, val: Self::Buffer);

fn write_spec(&mut self, ident: I, spec: buffers::Spec<Self::Buffer>) {
self.write(ident, spec.unwrap());
}
}

/// Helper macro for defining simple, auto-magical [Context] types.
#[macro_export]
macro_rules! ctx_type {
@@ -133,6 +143,16 @@ macro_rules! ctx {

/// Base trait for operator nodes.
pub trait Node {
fn cached<I: Identifier>(self, ident: I) -> meta::Cached<Self, I>
where
Self: Sized,
{
meta::Cached {
node: self,
ident,
}
}

fn add<N: Node>(self, other: N) -> ops::Add<Self, N>
where
Self: Sized,
@@ -243,7 +263,7 @@ pub trait Function<C: Context>: Node {
///
/// assert_eq!(x.evaluate(ctx!{X = 1.0}).unwrap(), 1.0);
/// ```
fn evaluate<CR: AsRef<C>>(&self, ctx: CR) -> AegirResult<Self, C>;
fn evaluate<CR: AsMut<C>>(&self, ctx: CR) -> AegirResult<Self, C>;

/// Evaluate the function and return its lifted [Value](Function::Value).
///
@@ -257,7 +277,7 @@ pub trait Function<C: Context>: Node {
///
/// assert_eq!(jx.evaluate_spec(ctx!{X = [1.0, 2.0]}).unwrap(), Spec::Diagonal(S2, 1.0));
/// ```
fn evaluate_spec<CR: AsRef<C>>(
fn evaluate_spec<CR: AsMut<C>>(
&self,
ctx: CR,
) -> Result<buffers::Spec<Self::Value>, Self::Error> {
@@ -270,7 +290,7 @@ pub trait Function<C: Context>: Node {
/// __Note:__ by default, this method performs a full evaluation and calls
/// the shape method on the buffer. This should be overridden in your
/// implementation for better efficiency.
fn evaluate_shape<CR: AsRef<C>>(
fn evaluate_shape<CR: AsMut<C>>(
&self,
ctx: CR,
) -> Result<buffers::shapes::ShapeOf<Self::Value>, Self::Error> {
@@ -308,7 +328,7 @@ pub trait Differentiable<T: Identifier>: Node {
/// __Note:__ this method can be more efficient than explicitly solving for
/// the adjoint tree. In particular, this method can be implemented
/// using direct numerical calculations.
fn evaluate_adjoint<C: Context, CR: AsRef<C>>(
fn evaluate_adjoint<C: Context, CR: AsMut<C>>(
&self,
target: T,
ctx: CR,
@@ -322,10 +342,10 @@ pub trait Differentiable<T: Identifier>: Node {

/// Helper method that evaluates the function and its adjoint, wrapping up
/// in a [Dual].
fn evaluate_dual<C: Context, CR: AsRef<C>>(
fn evaluate_dual<C: Context, CR: AsMut<C>>(
&self,
target: T,
ctx: CR,
mut ctx: CR,
) -> Result<
DualOf<Self, C, T>,
BinaryError<Self::Error, <AdjointOf<Self, T> as Function<C>>::Error, NoError>,
@@ -334,7 +354,7 @@ pub trait Differentiable<T: Identifier>: Node {
Self: Function<C>,
Self::Adjoint: Function<C>,
{
let value = self.evaluate(&ctx).map_err(BinaryError::Left)?;
let value = self.evaluate(&mut ctx).map_err(BinaryError::Left)?;
let adjoint = self.evaluate_adjoint(target, ctx).map_err(BinaryError::Right)?;

Ok(dual!(value, adjoint))
51 changes: 51 additions & 0 deletions aegir/src/meta/cached.rs
@@ -0,0 +1,51 @@
use crate::{
    Node,
    Write,
    Contains,
    Function,
    Identifier,
    AegirResult,
};

#[derive(Copy, Clone, Debug, PartialEq)]
pub struct Cached<N, T> {
    pub node: N,
    pub ident: T,
}

impl<N: Node, T: Identifier> Node for Cached<N, T> {}

impl<N, T, I> Contains<I> for Cached<N, T>
where
    N: Contains<I>,
    T: Identifier,
    I: Identifier + PartialEq<T>,
{
    fn contains(&self, ident: I) -> bool {
        ident == self.ident || self.node.contains(ident)
    }
}

impl<N, T, C> Function<C> for Cached<N, T>
where
    N: Function<C, Value = C::Buffer>,
    T: Identifier,
    C: Write<T>,
{
    type Error = N::Error;
    type Value = N::Value;

    fn evaluate<CR: AsMut<C>>(&self, mut ctx: CR) -> AegirResult<Self, C> {
        let cached = ctx.as_mut().read(self.ident);

        if let Some(value) = cached {
            return Ok(value);
        }

        let value = self.node.evaluate(ctx.as_mut())?;

        ctx.as_mut().write(self.ident, value.clone());

        Ok(value)
    }
}
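
`Cached::evaluate` is a memoisation wrapper around the inner node: it reads `self.ident` from the context, returns the stored buffer if one is present, and otherwise evaluates the wrapped node, writes the result back under that identifier, and returns it. A minimal standalone sketch of the same read-check-write flow, using hypothetical names and no aegir types, just to isolate the pattern:

    // Illustration only -- not aegir's API. `Memoized` plays the role of `Cached`
    // and `slot` plays the role of the context's cache entry.
    struct Memoized<F: Fn() -> f64> {
        compute: F,
        slot: Option<f64>,
    }

    impl<F: Fn() -> f64> Memoized<F> {
        fn evaluate(&mut self) -> f64 {
            // Return the previously written value, if any.
            if let Some(value) = self.slot {
                return value;
            }

            // Otherwise compute, write the result back, and return it.
            let value = (self.compute)();
            self.slot.replace(value);
            value
        }
    }

    fn main() {
        let mut node = Memoized {
            compute: || (1..=100u32).map(f64::from).sum::<f64>(),
            slot: None,
        };

        let first = node.evaluate();  // computed and written to the slot
        let second = node.evaluate(); // read straight back from the slot
        assert_eq!(first, second);
    }

One difference worth noting: in the sketch the cache lives on the wrapper itself, whereas `Cached` keeps the value in the shared context, so several copies of the same cached node (as in the example above) can reuse a single entry.
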
16 changes: 8 additions & 8 deletions aegir/src/meta/constant.rs
@@ -60,15 +60,15 @@
type Error = SourceError<()>;
type Value = S::Buffer;

fn evaluate<CR: AsRef<C>>(&self, ctx: CR) -> Result<S::Buffer, Self::Error> {
fn evaluate<CR: AsMut<C>>(&self, ctx: CR) -> Result<S::Buffer, Self::Error> {
self.evaluate_spec(ctx).map(|spec| spec.unwrap())
}

fn evaluate_spec<CR: AsRef<C>>(&self, _: CR) -> Result<Spec<S::Buffer>, Self::Error> {
fn evaluate_spec<CR: AsMut<C>>(&self, _: CR) -> Result<Spec<S::Buffer>, Self::Error> {
Ok(self.0.clone().into_spec())
}

fn evaluate_shape<CR: AsRef<C>>(&self, _: CR) -> Result<S::Shape, Self::Error> {
fn evaluate_shape<CR: AsMut<C>>(&self, _: CR) -> Result<S::Shape, Self::Error> {
Ok(self.0.shape())
}
}
@@ -162,21 +162,21 @@ where
type Error = crate::BinaryError<N::Error, SourceError<T>, crate::NoError>;
type Value = <CA as Class<SA>>::Buffer<F>;

fn evaluate<CR: AsRef<C>>(&self, ctx: CR) -> Result<Self::Value, Self::Error> {
fn evaluate<CR: AsMut<C>>(&self, ctx: CR) -> Result<Self::Value, Self::Error> {
self.evaluate_spec(ctx).map(|lifted| lifted.unwrap())
}

fn evaluate_spec<CR: AsRef<C>>(&self, ctx: CR) -> Result<Spec<Self::Value>, Self::Error> {
fn evaluate_spec<CR: AsMut<C>>(&self, ctx: CR) -> Result<Spec<Self::Value>, Self::Error> {
self.evaluate_shape(ctx)
.map(|shape| Spec::Full(shape, F::zero()))
}

fn evaluate_shape<CR: AsRef<C>>(&self, ctx: CR) -> Result<SA, Self::Error> {
fn evaluate_shape<CR: AsMut<C>>(&self, mut ctx: CR) -> Result<SA, Self::Error> {
let shape_value = self
.node
.evaluate_shape(ctx.as_ref())
.evaluate_shape(&mut ctx)
.map_err(crate::BinaryError::Left)?;
let shape_target = ctx.as_ref().read(self.target).map(|buf| buf.shape()).ok_or(
let shape_target = ctx.as_mut().read(self.target).map(|buf| buf.shape()).ok_or(
crate::BinaryError::Right(SourceError::Undefined(self.target)),
)?;

3 changes: 3 additions & 0 deletions aegir/src/meta/mod.rs
@@ -62,3 +62,6 @@ pub use self::constant::{Constant, ConstantAdjoint};

mod variable;
pub use self::variable::{Variable, VariableAdjoint};

mod cached;
pub use self::cached::Cached;
46 changes: 23 additions & 23 deletions aegir/src/meta/variable.rs
@@ -58,20 +58,20 @@
type Error = SourceError<I>;
type Value = C::Buffer;

fn evaluate<CR: AsRef<C>>(&self, ctx: CR) -> Result<Self::Value, Self::Error> {
ctx.as_ref()
fn evaluate<CR: AsMut<C>>(&self, mut ctx: CR) -> Result<Self::Value, Self::Error> {
ctx.as_mut()
.read(self.0)
.ok_or_else(|| SourceError::Undefined(self.0))
}

fn evaluate_spec<CR: AsRef<C>>(&self, ctx: CR) -> Result<Spec<Self::Value>, Self::Error> {
ctx.as_ref()
fn evaluate_spec<CR: AsMut<C>>(&self, mut ctx: CR) -> Result<Spec<Self::Value>, Self::Error> {
ctx.as_mut()
.read_spec(self.0)
.ok_or_else(|| SourceError::Undefined(self.0))
}

fn evaluate_shape<CR: AsRef<C>>(&self, ctx: CR) -> Result<ShapeOf<Self::Value>, Self::Error> {
ctx.as_ref()
fn evaluate_shape<CR: AsMut<C>>(&self, mut ctx: CR) -> Result<ShapeOf<Self::Value>, Self::Error> {
ctx.as_mut()
.read_shape(self.0)
.ok_or_else(|| SourceError::Undefined(self.0))
}
@@ -158,21 +158,21 @@ where
type Error = crate::BinaryError<SourceError<I>, SourceError<T>, crate::NoError>;
type Value = <CA as Class<SA>>::Buffer<F>;

fn evaluate<CR: AsRef<C>>(&self, ctx: CR) -> Result<Self::Value, Self::Error> {
fn evaluate<CR: AsMut<C>>(&self, ctx: CR) -> Result<Self::Value, Self::Error> {
self.evaluate_spec(ctx).map(|spec| spec.unwrap())
}

fn evaluate_spec<CR: AsRef<C>>(&self, ctx: CR) -> Result<Spec<Self::Value>, Self::Error> {
fn evaluate_spec<CR: AsMut<C>>(&self, mut ctx: CR) -> Result<Spec<Self::Value>, Self::Error> {
let shape_value = ctx
.as_ref()
.as_mut()
.read_shape(self.value)
.ok_or(crate::BinaryError::Left(SourceError::Undefined(self.value)))?;
let shape_target =
ctx.as_ref()
.read_shape(self.target)
.ok_or(crate::BinaryError::Right(SourceError::Undefined(
self.target,
)))?;
let shape_target = ctx
.as_mut()
.read_shape(self.target)
.ok_or(crate::BinaryError::Right(SourceError::Undefined(
self.target,
)))?;
let shape_adjoint = shape_value.concat(shape_target);

Ok(if self.value == self.target {
@@ -200,17 +200,17 @@ where
})
}

fn evaluate_shape<CR: AsRef<C>>(&self, ctx: CR) -> Result<SA, Self::Error> {
fn evaluate_shape<CR: AsMut<C>>(&self, mut ctx: CR) -> Result<SA, Self::Error> {
let shape_value = ctx
.as_ref()
.as_mut()
.read_shape(self.value)
.ok_or(crate::BinaryError::Left(SourceError::Undefined(self.value)))?;
let shape_target =
ctx.as_ref()
.read_shape(self.target)
.ok_or(crate::BinaryError::Right(SourceError::Undefined(
self.target,
)))?;
let shape_target = ctx
.as_mut()
.read_shape(self.target)
.ok_or(crate::BinaryError::Right(SourceError::Undefined(
self.target,
)))?;

Ok(shape_value.concat(shape_target))
}