The Lc0 XLA backend uses the OpenXLA compiler to produce code that executes a neural network. The backend takes ONNX as input and converts it to HLO, the format that XLA understands.
The C API that OpenXLA implements is called PjRT, and the Lc0 XLA backend runs the network through it.
Pros
- On NVIDIA devices, it runs networks at speeds comparable to the handwritten cuda/cudnn backend (and sometimes faster).
- Only requires ONNX as input (either external, or converted by the Lc0 ONNX converter), meaning it can run networks for which other backends are not written yet.
- Optimizes for the particular GPU (it produces different code for A100 vs H100, and for different memory sizes).
- Supports running on TPUs.
- Supports multi-GPU / multi-host setups for huge nets (not sure we will ever need it).
Cons
- Painful to build
- Linux only (in theory, it runs on Windows through WSL).
- Slow backend startup: networks are compiled when they are loaded, which may take a few minutes.
Building XLA Lc0 backend
To build it, just add the `-Dxla` parameter to `./build.sh`. There are no extra dependencies to build it, but it needs the PjRT plugin (a `.so` file) at runtime.
Obtaining XLA PjRT plugin (a .so file)
Using a prebuilt one
Most likely you'll have to build it yourself, but you may try the prebuilt XLA library from the elixir-nx repository. It is built from the entire XLA repository, though, so the resulting `.so` file is ≈20% larger, and no one has tested yet whether the function `GetPjrtApi()` is actually exported.
Alternatively, you can use the prebuilt plugin that ships with the CUDA-enabled versions of JAX:
- Enter a Python environment, e.g. `python3 -m virtualenv venv; source venv/bin/activate`.
- Install the JAX backend (PjRT) library: `pip install jax-cuda12-pjrt` (a CUDA 11 variant is also available).
- Copy the plugin file found at `venv/lib/python3.10/site-packages/jax_plugins/xla_cuda12/xla_cuda_plugin.so` (you may need to adapt the path). You may rename it to `pjrt_c_api_gpu_plugin.so`.
- As of writing, the symbol `GetPjrtApi` is confirmed to be exported, so you should be able to use this prebuilt library.
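To check yourself whether a prebuilt library exports `GetPjrtApi`, one option is to load it with Python's `ctypes`, which calls `dlopen()`/`dlsym()` underneath. This is a sketch, not part of Lc0, and the helper name is made up. Loading the plugin also runs its initialization code and requires its own dependencies (e.g. CUDA libraries) to resolve. `nm -D` on the file is a lighter-weight alternative.

```python
import ctypes
from typing import Optional


def exports_symbol(library_path: Optional[str], symbol: str = "GetPjrtApi") -> bool:
    """Check whether a shared library exports `symbol` (hypothetical helper).

    ctypes.CDLL() calls dlopen(), so `library_path` is interpreted the same
    way as a dlopen() filename; None means "the running process itself".
    """
    lib = ctypes.CDLL(library_path)  # raises OSError if the .so can't be loaded
    # Attribute lookup calls dlsym(); a missing symbol raises AttributeError,
    # which hasattr() turns into False.
    return hasattr(lib, symbol)
```

For example, `exports_symbol("./pjrt_c_api_gpu_plugin.so")` should return `True` for a usable plugin.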
Building it yourself
- Clone the openxla/xla repo: `git clone https://github.com/openxla/xla.git`
- Install either Bazel 6.5.0 or (better) Bazelisk, which installs the correct version of Bazel automatically.
- Inside the repo, run `./configure.py` with the desired parameters (see `./configure.py --help`), for example:
  `TF_CUDA_PATHS=/opt/cuda,/usr ./configure.py --backend CUDA --cudnn_version=8 --cuda_compute_capabilities=5.2`
- Build the plugin (this takes a few hours): `bazel build -c opt //xla/pjrt/c:pjrt_c_api_gpu_plugin.so`
- The resulting plugin will be located at `<xla-repo>/bazel-bin/xla/pjrt/c/pjrt_c_api_gpu_plugin.so`. Either put it into the same directory as your `./lc0`, or specify the path to it in the Lc0 backend options.
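Where the plugin ends up matters because the backend loads it following the `dlopen()` naming convention. Below is a simplified Python model of which files would be tried for a given plugin path. The helper name is made up for illustration and is not part of Lc0. The real loader also consults RPATH/RUNPATH, `ld.so.cache` and `lib64` directories, which are left out here.

```python
import os
from typing import List, Optional


def plugin_search_locations(plugin_path: str, cwd: Optional[str] = None) -> List[str]:
    """Simplified model of where dlopen() looks for `plugin_path` (hypothetical helper)."""
    if "/" in plugin_path:
        # Contains a slash: used as a pathname. A relative path is resolved
        # against the current working directory, not the lc0 binary's directory.
        return [os.path.normpath(os.path.join(cwd or os.getcwd(), plugin_path))]
    # No slash: only the loader's search locations are tried (simplified:
    # LD_LIBRARY_PATH entries followed by the default system directories).
    dirs = [d for d in os.environ.get("LD_LIBRARY_PATH", "").split(":") if d]
    dirs += ["/lib", "/usr/lib"]
    return [os.path.join(d, plugin_path) for d in dirs]
```

With `plugin_path=./pjrt_c_api_gpu_plugin.so`, the plugin is found next to `./lc0` only when Lc0 is started from its own directory. An absolute path avoids that dependency.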
Obtaining PjRT plugin to run on Google TPU
The PjRT XLA plugin for Google TPUs is called `libtpu.so`. No one has tried it yet, though.
It's available at https://storage.googleapis.com/cloud-tpu-tpuvm-artifacts/libtpu/2023-09-12/libtpu.so (or check for newer versions here).
Running XLA backend
Specify `--backend=xla` as usual. Supported options (passed through `--backend-opts`):
- `device=0`: use GPU number 0. (It's called `device` rather than `gpu` because it's not necessarily a GPU; it could be e.g. a TPU.)
- `plugin_path=./pjrt_c_api_gpu_plugin.so`: path to the PjRT plugin. Note that the filename follows the `dlopen()` convention. If it contains a slash (e.g. starts with `./`), it's a relative or absolute path to a file. If it doesn't contain a slash, only a few system-defined locations (e.g. `LD_LIBRARY_PATH` directories) are searched.
- `data_type=f32`: the data type to use inside the NN (currently supported are `f32`, `f16` and `bf16`). Currently, the backend communicates with the GPU using `f32` in all cases, though.
- `max_batch=512` and `steps=16`: define for which batch sizes to compile the code. Currently the XLA backend only supports static batch sizes, and chooses the smallest suitable batch size for evaluation. E.g. if `max_batch=9` and `steps=3`, kernels for batch sizes 3, 6 and 9 will be generated.
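The batch-size selection described above can be sketched as follows. This is an illustration under stated assumptions, not the backend's code. It assumes the `steps` compiled sizes are evenly spaced multiples of `max_batch / steps` (rounded up), which matches the 3, 6, 9 example. It also assumes a smaller batch is padded up to the chosen size. The helper names are hypothetical.

```python
import bisect
from typing import List, Optional


def compiled_batch_sizes(max_batch: int, steps: int) -> List[int]:
    """Batch sizes assumed to get a compiled executable (hypothetical helper).

    Assumption: size_i = ceil(i * max_batch / steps) for i = 1..steps,
    so the largest size is always max_batch.
    """
    # -(-a // b) is integer ceiling division.
    return [-(-i * max_batch // steps) for i in range(1, steps + 1)]


def pick_batch_size(batch: int, sizes: List[int]) -> Optional[int]:
    """Smallest compiled size that can hold `batch` (`sizes` must be sorted)."""
    i = bisect.bisect_left(sizes, batch)
    # None means the batch exceeds max_batch; what the real backend does in
    # that case is not modelled by this sketch.
    return sizes[i] if i < len(sizes) else None
```

With `max_batch=9, steps=3` this gives `[3, 6, 9]`, and a batch of 4 positions would run on the size-6 executable (padded, under the assumption above). More steps mean less padding waste but, presumably, more executables to compile at startup.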
Updating/Debugging the XLA backend
To inspect the generated HLO, several options have been added to the `leela2onnx` command, for example:
- `--hlo-text-output=filename.txt`: text dump of the HLO.
- `--hlo-proto-output=filename.pb`: dump of the HLO as a proto. This can be fed to the XLA tools (e.g. `hlo-opt`, to see the kernels it generates).
If you need to add a converter for an ONNX op that is not currently supported (or not fully supported), here is how it can be done:
- Convert your net to ONNX, e.g. `./lc0 leela2onnx --input=t79.pb.gz --output=t79.onnx`
- Convert your ONNX net to HLO (in text and proto form) using `jaxonnxruntime`. You'll need the `.hlo` file to see how the op is implemented in HLO (for the `onnx2hlo.cc` code), and the `.hlo.proto` file to find out how new HLO opcodes (if any) are encoded in the proto (for the `hlo_builder.cc` code):
import onnx
import numpy as np
import jax
from jaxonnxruntime import config
from jaxonnxruntime.call_onnx import call_onnx_model

config.update('jaxort_only_allow_initializers_as_static_args', False)

MODEL_FILE = 't79.onnx'
BATCH_SIZE = 14
# Lc0 nets take 112 input planes of 8x8 squares; the f32 ONNX net expects float32 input.
INPUTS = [np.zeros((BATCH_SIZE, 112, 8, 8), dtype=np.float32)]

onnx_model = onnx.load(MODEL_FILE)
m = call_onnx_model(onnx_model, INPUTS)
# m[0] is the model function, m[1] its parameters.
lowered = jax.jit(m[0]).lower(m[1], INPUTS)
hlo = lowered.compiler_ir('hlo')

# Text form: shows how each op is expressed in HLO.
with open(MODEL_FILE + ".hlo", "w") as f:
    f.write(hlo.as_hlo_text())
# Serialized HloModuleProto: shows how the opcodes are encoded.
with open(MODEL_FILE + ".hlo.proto", "wb") as f:
    f.write(hlo.as_serialized_hlo_module_proto())
- Convert the ONNX and HLO protos to text format. ONNX:
  `cat t79.onnx | protoc -I ~/path/to/lc0/src/neural/onnx/ ~/path/to/lc0/src/neural/onnx/onnx.proto --decode=pblczero.ModelProto > t79.onnx.asciiproto`
  HLO:
  `cat t79.onnx.hlo.proto | protoc -I ~/path/to/xla/ ~/path/to/xla/xla/service/hlo.proto --decode=xla.HloModuleProto > t79.hlo.asciiproto`
- Add the missing HLO opcodes to `hlo_builder.cc` and the ONNX operators to `onnx2hlo.cc`. When adding new HLO opcodes, you often need to add more fields to `hlo.proto`. When that happens, also update `print_hlo.cc` to show them.
Also see the description of the initial PR for more details.
- To iterate on changes, it's useful to use the (hidden) `--hlo-allow-partial-result` flag of `leela2onnx` to see the partial output of the HLO conversion even if there is an error, for example:
  `./lc0 leela2onnx --input=BT3-768x15x24h-swa-2790000.pb.gz --hlo-text-output=- --hlo-allow-partial-result --hlo-batch-size=333 --onnx-data-type=bf16`
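When deciding which converters are missing, it can help to list the op types the ONNX net uses and the HLO opcodes `jaxonnxruntime` emitted for them. Below is a small sketch over the text dumps produced with `protoc`; the helper name is hypothetical. It assumes protoc's usual text format, where each string field is printed on its own line as `field: "value"`. The relevant fields are `op_type` in ONNX `NodeProto` and `opcode` in XLA `HloInstructionProto`.

```python
import collections
import re


def count_field_values(asciiproto_text: str, field: str) -> collections.Counter:
    """Count the values of a string field in `protoc --decode` text output.

    Use field="op_type" for the ONNX dump and field="opcode" for the HLO dump.
    Assumes one `field: "value"` entry per line (protoc's usual layout).
    """
    pattern = re.compile(r'^\s*' + re.escape(field) + r':\s*"([^"]*)"', re.MULTILINE)
    return collections.Counter(pattern.findall(asciiproto_text))
```

For example, on a dump of a graph with two `Conv` nodes and one `Relu` node, `count_field_values(text, "op_type")` returns `Counter({'Conv': 2, 'Relu': 1})`. Comparing that list against the ops `onnx2hlo.cc` already handles shows what still needs a converter.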