Installing JAX on the Hailo-8 platform requires a comprehensive understanding of both the JAX framework and the Hailo-8 hardware/software stack. The Hailo-8 is a specialized AI accelerator designed for edge devices, optimized for running deep learning inference tasks with high efficiency and low power consumption. JAX, developed by Google, is a Python library for high-performance numerical computing, supporting transformations like automatic differentiation and just-in-time (JIT) compilation. JAX is primarily designed to target CPUs, GPUs (via CUDA), and TPUs, but not specialized accelerators like Hailo-8 out-of-the-box. Integrating JAX with Hailo-8 thus involves a multi-stage process, including model development in JAX, exporting models, and deploying them to Hailo-8 using its dedicated SDK and toolchain.
Understanding the Stack: JAX and Hailo-8
JAX operates by transforming Python and NumPy code and compiling it to run efficiently on supported hardware backends, primarily using XLA (Accelerated Linear Algebra compiler). Hailo-8, on the other hand, is not natively supported by the XLA backend, nor does it provide a direct Python runtime environment for executing JAX code. Instead, the typical workflow involves:
1. Training and developing models in a high-level framework.
2. Exporting models to a standard intermediate representation (often ONNX or TensorFlow SavedModel).
3. Converting and compiling the exported model with the Hailo Software Suite for deployment on Hailo-8 hardware.
Detailed Workflow for Using JAX Models on Hailo-8
1. Develop and Train the Model Using JAX
JAX is commonly used for developing and training neural network models, often in conjunction with libraries such as Flax or Haiku for higher-level abstractions. For instance, training a convolutional neural network (CNN) on the MNIST dataset in JAX/Flax might look like this:
```python
import jax
import jax.numpy as jnp
from flax import linen as nn
import optax

class SimpleCNN(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.Conv(features=32, kernel_size=(3, 3))(x)
        x = nn.relu(x)
        x = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = x.reshape((x.shape[0], -1))  # Flatten
        x = nn.Dense(features=10)(x)
        return x

# Model training code...
```
Model training and evaluation are performed on CPUs or GPUs using JAX's fast primitives.
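The elided training step can be sketched in plain JAX with `jax.value_and_grad` and a manual SGD update. The logistic-regression stand-in below is illustrative only (a real loop would use the Flax CNN above together with optax); all names here are hypothetical:

```python
import jax
import jax.numpy as jnp

def init_params(key, in_dim=784, out_dim=10):
    # Small random weights and zero biases for a single linear layer.
    w_key, _ = jax.random.split(key)
    return {
        "w": jax.random.normal(w_key, (in_dim, out_dim)) * 0.01,
        "b": jnp.zeros(out_dim),
    }

def loss_fn(params, x, y):
    # Softmax cross-entropy against one-hot labels.
    logits = x @ params["w"] + params["b"]
    log_probs = jax.nn.log_softmax(logits)
    return -jnp.mean(jnp.sum(y * log_probs, axis=-1))

@jax.jit
def train_step(params, x, y, lr=0.05):
    # One step of plain gradient descent over the parameter pytree.
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
    return params, loss
```

Repeated calls to `train_step` on a batch should drive the loss down, which is a quick sanity check before moving to export.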
2. Export the JAX Model
Core JAX has no one-step exporter to ONNX or TensorFlow SavedModel. Two practical routes lead from a trained JAX model to a format compatible with Hailo-8’s toolchain: convert the model to TensorFlow with `jax.experimental.jax2tf` and save it as a SavedModel, or follow the reimplementation path below:
- Reimplement the model in TensorFlow or PyTorch: Since Hailo-8 supports ONNX and TensorFlow, re-implement the trained model’s architecture in TensorFlow or PyTorch and port the weights.
- Export to ONNX or TensorFlow SavedModel: If you use TensorFlow, you can save the model as a SavedModel. If using PyTorch, export the model to ONNX.
For example, to export a PyTorch model to ONNX:
3. Convert and Compile the Model Using Hailo’s Software Suite
The Hailo Software Suite (HailoRT, the Hailo Model Zoo, and Hailo TappAS) supports importing models in ONNX and TensorFlow formats. The workflow involves the steps below; the command and flag names shown are illustrative, so consult the current Hailo documentation for the exact CLI:
- Import the Model: Use Hailo TappAS or the command-line tools to import the ONNX or TensorFlow model.
```bash
hailomodels import --model-path model.onnx --output-path my_hailo_model.hef
```
- Profile and Quantize the Model: Hailo-8 operates with quantized (INT8) models for efficiency. The Software Suite includes profiling and quantization tools to convert the floating-point model into a quantized format suitable for Hailo-8.
```bash
hailomodels profile --model-path my_hailo_model.hef --data-path calibration_dataset/
hailomodels quantize --model-path my_hailo_model.hef --profile-path profile_results/
```
- Compile the Model: The quantized model is then compiled into a Hailo Executable File (.hef) which can be deployed to the Hailo-8 accelerator.
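Following the illustrative `hailomodels` naming used in the steps above (these command and flag names are placeholders rather than the verified Hailo CLI), the compile step might look like:

```bash
# Illustrative placeholder, not the verified Hailo CLI: compile the
# quantized model into a deployable Hailo Executable File (.hef).
hailomodels compile --model-path my_hailo_model.hef --profile-path profile_results/ \
    --output-path compiled_model.hef
```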
4. Deploy the Model to Hailo-8
Once the `.hef` file is generated, it can be deployed to a device containing the Hailo-8 AI processor. The deployment process typically involves:
- Transferring the `.hef` file to the target device.
- Using HailoRT or Hailo’s runtime API (Python or C++) to load and run inference on the model.
Example code for inference using HailoRT’s Python API:
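A minimal sketch of such an inference call is shown below. It follows the class names documented for HailoRT's `hailo_platform` Python package, but it requires a device with an attached Hailo-8 and should be checked against the API of your installed HailoRT version:

```python
import numpy as np
from hailo_platform import (HEF, VDevice, ConfigureParams, HailoStreamInterface,
                            InferVStreams, InputVStreamParams, OutputVStreamParams)

# Load the compiled model and acquire the Hailo-8 device.
hef = HEF("my_hailo_model.hef")

with VDevice() as target:
    configure_params = ConfigureParams.create_from_hef(
        hef, interface=HailoStreamInterface.PCIe)
    network_group = target.configure(hef, configure_params)[0]

    input_params = InputVStreamParams.make(network_group)
    output_params = OutputVStreamParams.make(network_group)

    # Placeholder input matching the model's expected shape.
    input_info = hef.get_input_vstream_infos()[0]
    frame = np.zeros(input_info.shape, dtype=np.float32)

    with network_group.activate():
        with InferVStreams(network_group, input_params, output_params) as pipeline:
            results = pipeline.infer({input_info.name: np.expand_dims(frame, 0)})
```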
5. Validation and Optimization
After deployment, validate the model’s accuracy and throughput on the Hailo-8 device. Due to quantization, minor differences in accuracy compared to the floating-point (JAX) model may occur. Hailo’s tools provide options for further optimization, such as layer fusion, batch processing, and custom post-processing steps.
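One simple validation check is to run the same inputs through the original floating-point model and the deployed quantized model and measure how often their predictions agree. A pure-NumPy sketch, where `float_logits` and `quant_logits` stand in for the two sets of outputs:

```python
import numpy as np

def top1_agreement(float_logits, quant_logits):
    """Fraction of samples where the float and quantized models
    predict the same class."""
    return float(np.mean(
        np.argmax(float_logits, axis=-1) == np.argmax(quant_logits, axis=-1)))

# Illustrative data: quantized outputs as a slightly perturbed copy
# of the float outputs, mimicking small quantization error.
rng = np.random.default_rng(0)
float_logits = rng.normal(size=(100, 10))
quant_logits = float_logits + rng.normal(scale=0.01, size=(100, 10))
agreement = top1_agreement(float_logits, quant_logits)
```

A large drop in agreement on real data usually points to a poorly chosen calibration dataset or an unsupported layer being approximated.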
Key Points and Didactic Value
Compatibility and Portability: JAX models are not directly executable on Hailo-8 hardware. The workflow requires exporting the model parameters and architecture to a supported format (ONNX or TensorFlow). This process typically involves reimplementation or conversion (for example via `jax.experimental.jax2tf`), as core JAX provides no one-step ONNX exporter.
Toolchain Use: Familiarity with Hailo’s Software Suite is necessary. The tools provide a structured process for profiling, quantization, and compilation, allowing efficient deployment of deep learning models on the edge.
Model Quantization: Hailo-8 accelerates INT8-quantized models. The quantization step is critical for maximizing inference performance and minimizing resource consumption. Calibration datasets representative of real-world input are important for maintaining post-quantization accuracy.
Limitations and Workarounds: Direct execution of JAX code on Hailo-8 is not feasible due to the lack of a compatible backend. The workaround involves using JAX for model development and training, followed by migration to a Hailo-compatible framework for deployment.
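The INT8 quantization mentioned above can be made concrete with a generic affine-quantization sketch. This illustrates the arithmetic and the role of the calibration data in picking the scale; it is not Hailo's exact scheme, whose details are toolchain-specific:

```python
import numpy as np

def quantize_int8(x, scale, zero_point=0):
    """Affine-quantize float values to INT8: q = round(x / scale) + zp."""
    q = np.round(x / scale) + zero_point
    return np.clip(q, -128, 127).astype(np.int8)

def dequantize_int8(q, scale, zero_point=0):
    """Map INT8 codes back to approximate float values."""
    return (q.astype(np.float32) - zero_point) * scale

# Calibration: derive the scale from the value range observed in
# representative data, so the INT8 range covers the real distribution.
calib = np.random.default_rng(0).normal(size=10_000).astype(np.float32)
scale = np.max(np.abs(calib)) / 127.0

q = quantize_int8(calib, scale)
reconstructed = dequantize_int8(q, scale)
max_err = np.max(np.abs(calib - reconstructed))
```

The round-trip error is bounded by half a quantization step, which is why a calibration set that matches the real input distribution matters: an unrepresentative range inflates the scale and therefore the error.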
Example: End-to-End Workflow
Consider deploying a simple image classification model trained in JAX to a Hailo-8 device. The high-level steps are:
1. Train in JAX: Use Flax to train a CNN on your image dataset.
2. Reimplement in TensorFlow: Using the same architecture, create a TensorFlow version of the model.
3. Transfer Weights: Extract the trained weights from the JAX model and load them into the TensorFlow model. This may involve manual mapping or conversion scripts.
4. Export to ONNX/SavedModel: Save the TensorFlow model as a SavedModel or export to ONNX.
5. Use Hailo TappAS: Import, quantize, and compile the model into a `.hef` file using Hailo’s tools.
6. Deploy and Run Inference: Load the `.hef` file onto the Hailo-8 device and run inference using HailoRT.
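Step 3 (weight transfer) mostly reduces to layout bookkeeping. Flax and TensorFlow share the same conv-kernel layout (HWIO) and dense-kernel layout (in, out), so the Flax-to-TensorFlow transfer is often a direct copy; porting to PyTorch instead requires transposes. A NumPy sketch with illustrative shapes:

```python
import numpy as np

# Flax/TF conv kernels are laid out HWIO (height, width, in, out);
# PyTorch Conv2d weights are OIHW. Dense kernels are (in, out) in
# Flax/TF but (out, in) in PyTorch nn.Linear.

def flax_conv_to_torch(kernel_hwio):
    return np.transpose(kernel_hwio, (3, 2, 0, 1))  # HWIO -> OIHW

def flax_dense_to_torch(kernel_in_out):
    return kernel_in_out.T  # (in, out) -> (out, in)

# Example shapes matching a small conv + dense model.
conv_k = np.zeros((3, 3, 1, 32), dtype=np.float32)
dense_k = np.zeros((256, 10), dtype=np.float32)
```

Verifying a few layer outputs numerically after the transfer (same input through both frameworks) catches layout mistakes early.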
Additional Considerations
Batch Size and Latency: Hailo-8 supports configurable batch sizes that influence throughput and latency. Profiling with different batch sizes can identify optimal operational points for the application.
Model Support: Hailo provides a Model Zoo with reference models and examples. It is advisable to consult these resources for supported architectures and best practices.
Supported Operators: Not all neural network layers and operations are supported by Hailo-8. Review the Hailo documentation to ensure that your JAX-derived model uses compatible operations. Unsupported operations may require substitution or redesign.
Software Versions: Ensure compatibility between the versions of TensorFlow/PyTorch used for export and the Hailo Software Suite.
Automation: For large-scale deployment or frequent updates, consider scripting the conversion and deployment process using Hailo’s command-line tools and Python APIs.
References and Official Resources
- [JAX Documentation](https://jax.readthedocs.io)
- [Hailo Software Suite Documentation](https://docs.hailo.ai)
- [ONNX Documentation](https://onnx.ai)
- [TensorFlow SavedModel Guide](https://www.tensorflow.org/guide/saved_model)
- [Hailo Model Zoo](https://github.com/hailo-ai/hailo_model_zoo)
This workflow enables leveraging JAX for rapid prototyping and training of models, followed by efficient deployment on the Hailo-8 AI accelerator for edge inference scenarios.