Skip to content

Commit

Permalink
Add support for anoma-encode builtin (#2766)
Browse files Browse the repository at this point in the history
This PR adds support for the `anoma-encode` builtin:

```
builtin anoma-encode
axiom anomaEncode : {A : Type} -> A -> Nat
```

In the backend this is compiled to a call to the Anoma / nockma stdlib
`jam` function.

This PR also contains:
* An implementation of the `jam` function in Haskell. This is used in
the Nockma evaluator.
* Unit tests for `jam`
* A benchmark for `jam` applied to the Anoma / nockma stdlib.

Benchmark results:

```
$ juvixbench -p 'Jam'
All
  Nockma
    Jam
      jam stdlib: OK
        109  ms ± 6.2 ms
```
  • Loading branch information
paulcadman authored May 14, 2024
1 parent 6d660f5 commit 1ab94f5
Show file tree
Hide file tree
Showing 35 changed files with 388 additions and 10 deletions.
11 changes: 11 additions & 0 deletions bench2/Benchmark/Nockma.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
module Benchmark.Nockma where

import Benchmark.Nockma.Encoding qualified as NockmaEncoding
import Test.Tasty.Bench

bm :: Benchmark
bm =
bgroup
"Nockma"
[ NockmaEncoding.bm
]
21 changes: 21 additions & 0 deletions bench2/Benchmark/Nockma/Encoding.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
module Benchmark.Nockma.Encoding where

import Juvix.Compiler.Nockma.Encoding
import Juvix.Compiler.Nockma.Language
import Juvix.Compiler.Nockma.Stdlib (stdlib)
import Juvix.Prelude.Base
import Test.Tasty.Bench

bm :: Benchmark
bm =
bgroup
"Jam"
[bench "jam stdlib" $ nf runJam stdlib]

runJam :: Term Natural -> Natural
runJam =
(^. atom)
. fromRight (error "jam failed")
. run
. runError @NockNaturalNaturalError
. jam
4 changes: 3 additions & 1 deletion bench2/Main.hs
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
module Main where

import Benchmark.Effect qualified as Effect
import Benchmark.Nockma qualified as Nockma
import Juvix.Prelude
import Test.Tasty.Bench

main :: IO ()
main =
defaultMain
[ Effect.bm
[ Effect.bm,
Nockma.bm
]
2 changes: 2 additions & 0 deletions package.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ dependencies:
- ansi-terminal == 1.0.*
- base == 4.19.*
- base16-bytestring == 1.0.*
- bitvec == 1.1.*
- blaze-html == 0.9.*
- bytestring == 0.12.*
- cereal == 0.5.*
Expand Down Expand Up @@ -103,6 +104,7 @@ dependencies:
- unordered-containers == 0.2.*
- utf8-string == 1.0.*
- vector == 0.13.*
- vector-builder == 0.3.*
- versions == 6.0.*
- xdg-basedir == 0.2.*
- yaml == 0.11.*
Expand Down
1 change: 1 addition & 0 deletions src/Juvix/Compiler/Asm/Translation/FromTree.hs
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,7 @@ genCode fi =
Tree.OpTrace -> mkInstr Trace
Tree.OpFail -> mkInstr Failure
Tree.OpAnomaGet -> impossible
Tree.OpAnomaEncode -> impossible

snocReturn :: Bool -> Code' -> Code'
snocReturn True code = DL.snoc code (mkInstr Return)
Expand Down
13 changes: 13 additions & 0 deletions src/Juvix/Compiler/Builtins/Anoma.hs
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,16 @@ registerAnomaGet f = do
((ftype ==% (u <>--> u <>--> keyT --> valueT)) freeVars)
(error "anomaGet must be of type {Value Key : Type} -> Key -> Value")
registerBuiltin BuiltinAnomaGet (f ^. axiomName)

registerAnomaEncode :: (Members '[Builtins, NameIdGen] r) => AxiomDef -> Sem r ()
registerAnomaEncode f = do
let ftype = f ^. axiomType
u = ExpressionUniverse smallUniverseNoLoc
l = getLoc f
encodeT <- freshVar l "encodeT"
nat <- getBuiltinName (getLoc f) BuiltinNat
let freeVars = HashSet.fromList [encodeT]
unless
((ftype ==% (u <>--> encodeT --> nat)) freeVars)
(error "anomaEncode must be of type {A : Type} -> A -> Nat")
registerBuiltin BuiltinAnomaEncode (f ^. axiomName)
2 changes: 2 additions & 0 deletions src/Juvix/Compiler/Concrete/Data/Builtins.hs
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,7 @@ data BuiltinAxiom
| BuiltinIntToString
| BuiltinIntPrint
| BuiltinAnomaGet
| BuiltinAnomaEncode
| BuiltinPoseidon
| BuiltinEcOp
| BuiltinRandomEcPoint
Expand Down Expand Up @@ -223,6 +224,7 @@ instance Pretty BuiltinAxiom where
BuiltinIntToString -> Str.intToString
BuiltinIntPrint -> Str.intPrint
BuiltinAnomaGet -> Str.anomaGet
BuiltinAnomaEncode -> Str.anomaEncode
BuiltinPoseidon -> Str.cairoPoseidon
BuiltinEcOp -> Str.cairoEcOp
BuiltinRandomEcPoint -> Str.cairoRandomEcPoint
Expand Down
10 changes: 10 additions & 0 deletions src/Juvix/Compiler/Core/Evaluator.hs
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,7 @@ geval opts herr ctx env0 = eval' env0
OpFail -> failOp
OpTrace -> traceOp
OpAnomaGet -> anomaGetOp
OpAnomaEncode -> anomaEncodeOp
OpPoseidonHash -> poseidonHashOp
OpEc -> ecOp
OpRandomEcPoint -> randomEcPointOp
Expand Down Expand Up @@ -337,6 +338,15 @@ geval opts herr ctx env0 = eval' env0
err "unsupported builtin operation: OpAnomaGet"
{-# INLINE anomaGetOp #-}

anomaEncodeOp :: [Node] -> Node
anomaEncodeOp = unary $ \arg ->
if
| opts ^. evalOptionsNormalize || opts ^. evalOptionsNoFailure ->
mkBuiltinApp' OpAnomaEncode [eval' env arg]
| otherwise ->
err "unsupported builtin operation: OpAnomaGet"
{-# INLINE anomaEncodeOp #-}

poseidonHashOp :: [Node] -> Node
poseidonHashOp = unary $ \arg ->
if
Expand Down
1 change: 1 addition & 0 deletions src/Juvix/Compiler/Core/Extra/Utils.hs
Original file line number Diff line number Diff line change
Expand Up @@ -424,6 +424,7 @@ builtinOpArgTypes = \case
OpTrace -> [mkDynamic']
OpFail -> [mkTypeString']
OpAnomaGet -> [mkDynamic']
OpAnomaEncode -> [mkDynamic']
OpPoseidonHash -> [mkDynamic']
OpEc -> [mkDynamic', mkTypeField', mkDynamic']
OpRandomEcPoint -> []
Expand Down
5 changes: 4 additions & 1 deletion src/Juvix/Compiler/Core/Language/Builtins.hs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ data BuiltinOp
| OpTrace
| OpFail
| OpAnomaGet
| OpAnomaEncode
| OpPoseidonHash
| OpEc
| OpRandomEcPoint
Expand Down Expand Up @@ -71,6 +72,7 @@ builtinOpArgsNum = \case
OpTrace -> 1
OpFail -> 1
OpAnomaGet -> 1
OpAnomaEncode -> 1
OpPoseidonHash -> 1
OpEc -> 3
OpRandomEcPoint -> 0
Expand Down Expand Up @@ -108,6 +110,7 @@ builtinIsFoldable = \case
OpTrace -> False
OpFail -> False
OpAnomaGet -> False
OpAnomaEncode -> False
OpPoseidonHash -> False
OpEc -> False
OpRandomEcPoint -> False
Expand All @@ -122,4 +125,4 @@ builtinsCairo :: [BuiltinOp]
builtinsCairo = [OpPoseidonHash, OpEc, OpRandomEcPoint]

builtinsAnoma :: [BuiltinOp]
builtinsAnoma = [OpAnomaGet]
builtinsAnoma = [OpAnomaGet, OpAnomaEncode]
4 changes: 4 additions & 0 deletions src/Juvix/Compiler/Core/Pretty/Base.hs
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ instance PrettyCode BuiltinOp where
OpTrace -> return primTrace
OpFail -> return primFail
OpAnomaGet -> return primAnomaGet
OpAnomaEncode -> return primAnomaEncode
OpPoseidonHash -> return primPoseidonHash
OpEc -> return primEc
OpRandomEcPoint -> return primRandomEcPoint
Expand Down Expand Up @@ -801,6 +802,9 @@ primFail = primitive Str.fail_
primAnomaGet :: Doc Ann
primAnomaGet = primitive Str.anomaGet

primAnomaEncode :: Doc Ann
primAnomaEncode = primitive Str.anomaEncode

primPoseidonHash :: Doc Ann
primPoseidonHash = primitive Str.cairoPoseidon

Expand Down
1 change: 1 addition & 0 deletions src/Juvix/Compiler/Core/Transformation/ComputeTypeInfo.hs
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ computeNodeTypeInfo md = umapL go
_ -> error "incorrect trace builtin application"
OpFail -> Info.getNodeType node
OpAnomaGet -> Info.getNodeType node
OpAnomaEncode -> Info.getNodeType node
OpPoseidonHash -> case _builtinAppArgs of
[arg] -> Info.getNodeType arg
_ -> error "incorrect poseidon builtin application"
Expand Down
8 changes: 8 additions & 0 deletions src/Juvix/Compiler/Core/Translation/FromInternal.hs
Original file line number Diff line number Diff line change
Expand Up @@ -580,6 +580,7 @@ goAxiomInductive a = whenJust (a ^. Internal.axiomBuiltin) builtinInductive
Internal.BuiltinFieldFromInt -> return ()
Internal.BuiltinFieldToNat -> return ()
Internal.BuiltinAnomaGet -> return ()
Internal.BuiltinAnomaEncode -> return ()
Internal.BuiltinPoseidon -> return ()
Internal.BuiltinEcOp -> return ()
Internal.BuiltinRandomEcPoint -> return ()
Expand Down Expand Up @@ -700,6 +701,12 @@ goAxiomDef a = maybe goAxiomNotBuiltin builtinBody (a ^. Internal.axiomBuiltin)
(mkLambda' (mkVar' 0) (mkBuiltinApp' OpAnomaGet [mkVar' 0]))
)
)
Internal.BuiltinAnomaEncode ->
registerAxiomDef
( mkLambda'
mkSmallUniv
(mkLambda' (mkVar' 0) (mkBuiltinApp' OpAnomaEncode [mkVar' 0]))
)
Internal.BuiltinPoseidon -> do
psName <- getPoseidonStateName
psSym <- getPoseidonStateSymbol
Expand Down Expand Up @@ -1098,6 +1105,7 @@ goApplication a = do
_ -> app
Just Internal.BuiltinFieldToNat -> app
Just Internal.BuiltinAnomaGet -> app
Just Internal.BuiltinAnomaEncode -> app
Just Internal.BuiltinPoseidon -> app
Just Internal.BuiltinEcOp -> app
Just Internal.BuiltinRandomEcPoint -> app
Expand Down
1 change: 1 addition & 0 deletions src/Juvix/Compiler/Core/Translation/Stripped/FromCore.hs
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ fromCore fsize tab =
BuiltinIntToString -> False
BuiltinIntPrint -> False
BuiltinAnomaGet -> False
BuiltinAnomaEncode -> False
BuiltinPoseidon -> False
BuiltinEcOp -> False
BuiltinRandomEcPoint -> False
Expand Down
1 change: 1 addition & 0 deletions src/Juvix/Compiler/Internal/Translation/FromConcrete.hs
Original file line number Diff line number Diff line change
Expand Up @@ -571,6 +571,7 @@ registerBuiltinAxiom d = \case
BuiltinIntToString -> registerIntToString d
BuiltinIntPrint -> registerIntPrint d
BuiltinAnomaGet -> registerAnomaGet d
BuiltinAnomaEncode -> registerAnomaEncode d
BuiltinPoseidon -> registerPoseidon d
BuiltinEcOp -> registerEcOp d
BuiltinRandomEcPoint -> registerRandomEcPoint d
Expand Down
6 changes: 6 additions & 0 deletions src/Juvix/Compiler/Nockma/Encoding.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
module Juvix.Compiler.Nockma.Encoding
( module Juvix.Compiler.Nockma.Encoding.Jam,
)
where

import Juvix.Compiler.Nockma.Encoding.Jam
48 changes: 48 additions & 0 deletions src/Juvix/Compiler/Nockma/Encoding/Base.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
module Juvix.Compiler.Nockma.Encoding.Base where

import Data.Bit as Bit
import Data.Bits
import Data.Vector.Unboxed qualified as U
import Juvix.Compiler.Nockma.Encoding.Effect.BitWriter
import Juvix.Prelude.Base

-- | Binary encode an integer to a vector of bits, ordered from least to most significant bits.
-- NB: 0 is encoded as the empty bit vector is specified by the Hoon serialization spec
writeIntegral :: forall a r. (Integral a, Member BitWriter r) => a -> Sem r ()
writeIntegral x
| x < 0 = error "integerToVectorBits: negative integers are not supported in this implementation"
| otherwise = unfoldBits (fromIntegral x)
where
unfoldBits :: Integer -> Sem r ()
unfoldBits n
| n == 0 = return ()
| otherwise = writeBit (Bit (testBit n 0)) <> unfoldBits (n `shiftR` 1)

integerToVectorBits :: (Integral a) => a -> Bit.Vector Bit
integerToVectorBits = run . execBitWriter . writeIntegral

-- | Computes the number of bits required to store the argument in binary
-- NB: 0 is encoded to the empty bit vector (as specified by the Hoon serialization spec), so 0 has bit length 0.
bitLength :: forall a. (Integral a) => a -> Int
bitLength = \case
0 -> 0
n -> go (fromIntegral n) 0
where
go :: Integer -> Int -> Int
go 0 acc = acc
go x acc = go (x `shiftR` 1) (acc + 1)

-- | Decode a vector of bits (ordered from least to most significant bits) to an integer
vectorBitsToInteger :: Bit.Vector Bit -> Integer
vectorBitsToInteger = U.ifoldl' go 0
where
go :: Integer -> Int -> Bit -> Integer
go acc idx (Bit b)
| b = setBit acc idx
| otherwise = acc

-- | Transform a Natural to an Int, computes Nothing if the Natural does not fit in an Int
safeNaturalToInt :: Natural -> Maybe Int
safeNaturalToInt n
| n > fromIntegral (maxBound :: Int) = Nothing
| otherwise = Just (fromIntegral n)
51 changes: 51 additions & 0 deletions src/Juvix/Compiler/Nockma/Encoding/Effect/BitWriter.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
module Juvix.Compiler.Nockma.Encoding.Effect.BitWriter where

import Data.Bit as Bit
import Juvix.Prelude.Base
import VectorBuilder.Builder as Builder
import VectorBuilder.Vector

data BitWriter :: Effect where
WriteBit :: Bit -> BitWriter m ()
GetCurrentPosition :: BitWriter m Int

makeSem ''BitWriter

writeOne :: (Member BitWriter r) => Sem r ()
writeOne = writeBit (Bit True)

writeZero :: (Member BitWriter r) => Sem r ()
writeZero = writeBit (Bit False)

newtype WriterState = WriterState
{ _writerStateBuilder :: Builder Bit
}

makeLenses ''WriterState

initWriterState :: WriterState
initWriterState = WriterState {_writerStateBuilder = mempty}

runBitWriter :: forall a r. Sem (BitWriter ': r) a -> Sem r (Bit.Vector Bit, a)
runBitWriter sem = do
(s, res) <- runState initWriterState (re sem)
return (build (s ^. writerStateBuilder), res)

execBitWriter :: forall a r. Sem (BitWriter ': r) a -> Sem r (Bit.Vector Bit)
execBitWriter sem = do
s <- execState initWriterState (re sem)
return (build (s ^. writerStateBuilder))

re :: Sem (BitWriter ': r) a -> Sem (State WriterState ': r) a
re = interpretTop $ \case
WriteBit b -> writeBit' b
GetCurrentPosition -> getCurrentPosition'

writeBit' :: (Member (State WriterState) r) => Bit -> Sem r ()
writeBit' b = modify appendBit
where
appendBit :: WriterState -> WriterState
appendBit = over writerStateBuilder (<> Builder.singleton b)

getCurrentPosition' :: (Member (State WriterState) r) => Sem r Int
getCurrentPosition' = Builder.size <$> gets (^. writerStateBuilder)
Loading

0 comments on commit 1ab94f5

Please sign in to comment.