The 🔥 Deep Learning Framework
Install Lux from the Pkg REPL (press `]` at the Julia prompt):

] add Lux
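If you prefer a script over the Pkg REPL, the same installation can be done with the standard Pkg API:

```julia
# Equivalent to `] add Lux` in the Pkg REPL
import Pkg
Pkg.add("Lux")
```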
using Lux, Random, Optimisers, Zygote
# using LuxCUDA, LuxAMDGPU # Optional packages for GPU support
# Seeding
rng = Random.default_rng()
Random.seed!(rng, 0)
# Construct the layer
model = Chain(BatchNorm(128), Dense(128, 256, tanh), BatchNorm(256),
Chain(Dense(256, 1, tanh), Dense(1, 10)))
# Get the device determined by Lux
device = gpu_device()
# Parameter and State Variables
ps, st = Lux.setup(rng, model) .|> device
# Dummy Input
x = rand(rng, Float32, 128, 2) |> device
# Run the model
y, st = Lux.apply(model, x, ps, st)
# Gradients
gs = gradient(p -> sum(Lux.apply(model, x, p, st)[1]), ps)[1]
# Optimization
st_opt = Optimisers.setup(Optimisers.Adam(0.0001), ps)
st_opt, ps = Optimisers.update(st_opt, ps, gs)
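The snippet above performs a single optimizer step. As a rough illustration of how these pieces compose, here is a minimal training-loop sketch; the sum-of-squares loss, random targets, and hyperparameters are illustrative and not part of the original example:

```julia
# Minimal training-loop sketch (illustrative loss, data, and hyperparameters).
# The layers below carry no state (no BatchNorm), so `st` stays constant.
using Lux, Random, Optimisers, Zygote

rng = Random.default_rng()
Random.seed!(rng, 0)

model = Chain(Dense(128, 64, tanh), Dense(64, 10))
ps, st = Lux.setup(rng, model)

x = rand(rng, Float32, 128, 32)   # dummy input batch
y = rand(rng, Float32, 10, 32)    # dummy targets

opt_state = Optimisers.setup(Optimisers.Adam(0.001f0), ps)

function train(model, ps, st, opt_state, x, y; epochs = 100)
    for _ in 1:epochs
        # Gradient of a simple sum-of-squares loss w.r.t. the parameters
        gs = gradient(p -> sum(abs2, Lux.apply(model, x, p, st)[1] .- y), ps)[1]
        # Apply the update, keeping the new optimizer state and parameters
        opt_state, ps = Optimisers.update(opt_state, ps, gs)
    end
    return ps, opt_state
end

ps, opt_state = train(model, ps, st, opt_state, x, y)
```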
Look in the examples directory for self-contained usage examples. The documentation has examples sorted into proper categories.
Check out our Ecosystem page for more details.
For usage-related questions, please use GitHub Discussions or the JuliaLang Discourse (machine learning domain), both of which allow questions and answers to be indexed. To report bugs, use GitHub Issues, or, even better, send in a pull request.
Structure of the packages that are part of the Lux.jl Universe[^1] (rounded rectangles denote packages maintained by the Lux.jl developers); a small interoperability sketch follows the diagram:
flowchart LR
subgraph Interface
LuxCore(LuxCore)
end
subgraph Backend
LuxLib(LuxLib)
NNlib
CUDA
end
subgraph ExternalML[External ML Packages]
Flux
Metalhead
end
subgraph CompViz[Computer Vision]
Boltz(Boltz)
end
subgraph SciML[Scientific Machine Learning]
DeepEquilibriumNetworks(DeepEquilibriumNetworks)
DiffEqFlux(DiffEqFlux)
NeuralPDE[Neural PDE: PINNs]
end
subgraph AD[Automatic Differentiation]
Zygote
Enzyme["Enzyme (experimental)"]
end
subgraph Dist[Distributed Training]
FluxMPI(FluxMPI)
end
subgraph SerializeModels[Serialize Models]
Serial[Serialization]
JLD2
BSON
end
subgraph Opt[Optimization]
Optimisers
Optimization
end
subgraph Parameters
ComponentArrays
end
Lux(Lux)
Parameters --> Lux
LuxCore --> Lux
Backend --> Lux
Lux --> SciML
AD --> Lux
Lux --> Dist
Lux --> SerializeModels
Lux --> Opt
Lux --> CompViz
ExternalML -.-> CompViz
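As one concrete example of how these pieces interoperate: the parameters returned by `Lux.setup` are plain nested NamedTuples, so they can be flattened into a `ComponentArray` (ComponentArrays.jl) and checkpointed with JLD2. This is only a sketch; the model and file name are illustrative:

```julia
# Sketch: flattening Lux parameters (ComponentArrays) and checkpointing (JLD2)
using Lux, Random, ComponentArrays, JLD2

rng = Random.default_rng()
model = Chain(Dense(4, 8, tanh), Dense(8, 2))
ps, st = Lux.setup(rng, model)

# Flat-vector view of the nested parameters; fields remain addressable,
# e.g. ps_flat.layer_1.weight
ps_flat = ComponentArray(ps)

# Save parameters and states to disk, then restore them into this scope
@save "lux_checkpoint.jld2" ps st
@load "lux_checkpoint.jld2" ps st
```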
Comparisons with other Julia deep learning frameworks:

- Flux.jl -- We share most of the backend infrastructure with Flux (its roadmap hints at making Flux explicit-parameter first); see the brief calling-convention sketch after this list.
- Knet.jl -- One of the mature, original Julia deep learning frameworks.
- SimpleChains.jl -- Extremely efficient for small neural networks on the CPU.
- Avalon.jl -- Uses the tracing-based AD package Yota.jl.
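To make the first comparison concrete, here is a hedged sketch of Lux's explicit-parameter calling convention; the corresponding Flux call is only summarized in a comment, since Flux is not loaded here:

```julia
# Lux layers are immutable descriptions; parameters and state are passed in explicitly
using Lux, Random

rng = Random.default_rng()
layer = Dense(2, 3, tanh)
ps, st = Lux.setup(rng, layer)                   # parameters/state live outside the layer
y, st = layer(rand(rng, Float32, 2, 4), ps, st)  # explicit call: (x, ps, st) -> (y, st)

# In Flux, by contrast, the layer owns its weights, so the call is roughly
#   m = Flux.Dense(2 => 3, tanh); y = m(x)
# and gradients are taken with respect to the model object itself.
```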
If you found this library to be useful in academic work, then please cite:
@software{pal2023lux,
author = {Pal, Avik},
title = {{Lux: Explicit Parameterization of Deep Neural Networks in Julia}},
month = apr,
year = 2023,
note = {If you use this software, please cite it as below.},
publisher = {Zenodo},
version = {v0.5.0},
doi = {10.5281/zenodo.7808904},
url = {https://doi.org/10.5281/zenodo.7808904}
}
@thesis{pal2023efficient,
title = {{On Efficient Training \& Inference of Neural Differential Equations}},
author = {Pal, Avik},
year = {2023},
school = {Massachusetts Institute of Technology}
}
Also consider starring our GitHub repo.
[^1]: These packages constitute only a subset of the ecosystem; specifically, they are the packages the Lux.jl maintainers have personally tested. If you would like a new package to be listed here, please open an issue.