Implement your own Lc0 backend in four simple steps:
- Implement the Network and NetworkComputation interfaces.
- Write a factory function to create your backend.
- Register your factory function using the REGISTER_NETWORK macro.
- Link your implementation with Lc0.
Some details:
Factory function
The factory function has the following signature:
std::unique_ptr<Network> (const std::optional<WeightsFile>& w,
const OptionsDict& options);
On error, the factory function (or any other backend function) may throw an Exception from utils/exception.h.
The WeightsFile is a Protocol Buffer containing the NN weights.
The OptionsDict contains options passed to the backend through the --backend-opts command line flag.
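A factory function along these lines might look like the sketch below. To keep the fragment compilable on its own, the Lc0 types (Network, WeightsFile, OptionsDict, Exception) are replaced by simplified stand-ins, so their members here are assumptions rather than the real API. MyBackendNetwork, MakeMyBackend and the "threads" option are hypothetical names. A real backend would include the Lc0 headers instead and register the function with REGISTER_NETWORK, which is not shown.

```cpp
#include <cassert>
#include <map>
#include <memory>
#include <optional>
#include <stdexcept>
#include <string>

// --- Simplified stand-ins for Lc0 types (assumptions, not the real API) ---
struct WeightsFile {
  int num_blocks;  // The real WeightsFile is a Protocol Buffer.
};
struct OptionsDict {
  std::map<std::string, int> ints;
  int GetOrDefault(const std::string& key, int def) const {
    auto it = ints.find(key);
    return it == ints.end() ? def : it->second;
  }
};
struct Exception : std::runtime_error {
  using std::runtime_error::runtime_error;
};
class Network {
 public:
  virtual ~Network() = default;
};

// --- Hypothetical backend ---
class MyBackendNetwork : public Network {
 public:
  MyBackendNetwork(const WeightsFile& weights, int threads)
      : blocks_(weights.num_blocks), threads_(threads) {
    assert(threads_ >= 1);  // The factory validates this before construction.
  }
  int blocks() const { return blocks_; }
  int threads() const { return threads_; }

 private:
  int blocks_;
  int threads_;
};

// Factory with the signature described above. Errors are reported by
// throwing Exception.
std::unique_ptr<Network> MakeMyBackend(const std::optional<WeightsFile>& w,
                                       const OptionsDict& options) {
  if (!w) throw Exception("MyBackend requires a weights file.");
  const int threads = options.GetOrDefault("threads", 1);
  if (threads < 1) throw Exception("Option 'threads' must be positive.");
  return std::make_unique<MyBackendNetwork>(*w, threads);
}
```

Note that the weights argument is optional, so a backend that needs weights should check for their absence explicitly, as done here.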
Network interface
class Network {
public:
virtual const NetworkCapabilities& GetCapabilities() const = 0;
virtual std::unique_ptr<NetworkComputation> NewComputation() = 0;
};
A Network object is typically created only once per session. However, the network must not use any global variables or be a singleton, because there are exceptions:
- It’s possible for search to use several backends simultaneously. This typically happens when a user has multiple GPUs. In this case one backend per GPU is created.
- During training data generation, it’s possible e.g. to use one backend/weights file for white and another backend for black.
GetCapabilities() returns a backend capabilities struct:
struct NetworkCapabilities {
InputFormat input_format;
MovesLeftFormat moves_left;
};
InputFormat should be copied from the WeightsFile (it defines what data the 112 input planes should contain).
MovesLeftFormat should either be copied from the WeightsFile if the backend supports that format, or set to MOVES_LEFT_NONE otherwise.
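The rule for filling in the capabilities can be sketched as below. The enums and the WeightsFormat struct are simplified stand-ins (in Lc0 the formats come from the weights Protocol Buffer), and MakeCapabilities with its boolean flag is a hypothetical helper that reduces "the backend supports that format" to a single yes/no.

```cpp
#include <cassert>

// Stand-in enums; the enumerator lists here are illustrative only.
enum InputFormat { INPUT_UNKNOWN, INPUT_CLASSICAL_112_PLANE };
enum MovesLeftFormat { MOVES_LEFT_NONE, MOVES_LEFT_V1 };

struct NetworkCapabilities {
  InputFormat input_format;
  MovesLeftFormat moves_left;
};

// Hypothetical view of the format fields read from the WeightsFile.
struct WeightsFormat {
  InputFormat input;
  MovesLeftFormat moves_left;
};

// Copies the input format verbatim; advertises the moves left format only
// if the backend can actually evaluate that head.
NetworkCapabilities MakeCapabilities(const WeightsFormat& format,
                                     bool backend_supports_moves_left) {
  NetworkCapabilities caps;
  caps.input_format = format.input;
  caps.moves_left =
      backend_supports_moves_left ? format.moves_left : MOVES_LEFT_NONE;
  return caps;
}
```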
The NewComputation() function should return a NetworkComputation object which will evaluate an NN batch. The function must be thread safe. Expect to have at most one instance of NetworkComputation alive per thread at a given time. Most libraries don’t support simultaneous evaluation from multiple threads; in this case it makes sense to have a mutex in your Network class.
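The mutex suggestion above might be implemented roughly as follows. The two interfaces are cut down to the members needed to show the locking pattern, and MyNetwork, MyComputation and EvalLocked are hypothetical names. The counter increment stands in for a call into an evaluation library that must not run concurrently.

```cpp
#include <cassert>
#include <memory>
#include <mutex>

// Reduced stand-ins for the Lc0 interfaces (only what the example needs).
class NetworkComputation {
 public:
  virtual ~NetworkComputation() = default;
  virtual void ComputeBlocking() = 0;
};
class Network {
 public:
  virtual ~Network() = default;
  virtual std::unique_ptr<NetworkComputation> NewComputation() = 0;
};

class MyNetwork;  // Hypothetical backend, defined below.

class MyComputation : public NetworkComputation {
 public:
  explicit MyComputation(MyNetwork* network) : network_(network) {}
  void ComputeBlocking() override;

 private:
  MyNetwork* network_;  // Not owned; the network outlives its computations.
};

class MyNetwork : public Network {
 public:
  // Thread safe: only constructs a new object, touches no shared state.
  std::unique_ptr<NetworkComputation> NewComputation() override {
    return std::make_unique<MyComputation>(this);
  }

  // Serializes access to the (imaginary) non-thread-safe library.
  void EvalLocked() {
    std::lock_guard<std::mutex> lock(mutex_);
    ++active_;
    assert(active_ == 1);  // Never more than one evaluation at a time.
    ++evals_;              // Placeholder for the real library call.
    --active_;
  }

  int evals() const {
    std::lock_guard<std::mutex> lock(mutex_);
    return evals_;
  }

 private:
  mutable std::mutex mutex_;
  int active_ = 0;
  int evals_ = 0;
};

void MyComputation::ComputeBlocking() { network_->EvalLocked(); }
```

Keeping the mutex inside the Network object, rather than in a global, also fits the requirement that several backend instances can coexist.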
NetworkComputation interface
class NetworkComputation {
public:
// Adds a sample to the batch.
virtual void AddInput(InputPlanes&& input) = 0;
// Do the computation.
virtual void ComputeBlocking() = 0;
// Returns how many times AddInput() was called.
virtual int GetBatchSize() const = 0;
// Returns Q value of @sample.
virtual float GetQVal(int sample) const = 0;
// Returns D value of @sample.
virtual float GetDVal(int sample) const = 0;
// Returns P value @move_id of @sample.
virtual float GetPVal(int sample, int move_id) const = 0;
// Returns M (moves left head) value of @sample.
virtual float GetMVal(int sample) const = 0;
virtual ~NetworkComputation() {}
};
The NN eval workflow happens in the following way:
- A NetworkComputation is created (using the YourBackendNetwork::NewComputation() function).
- Encoded chess positions are added into the batch using the AddInput() function. Typically batches contain up to 512 entries, but this is configurable.
- The batch is evaluated by calling the ComputeBlocking() function (as the name suggests, this function should block until the eval is completed).
- NN output is accessed using the Get*() functions.
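From the caller's side, the workflow above might look like the sketch below. ToyComputation is a hypothetical backend whose "evaluation" just echoes the value of the first input plane as Q, and EvaluateBatch is a hypothetical helper. The computation is passed in already created, so the NewComputation() step is not shown. InputPlane is reduced to its two data fields.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <utility>
#include <vector>

struct InputPlane {
  std::uint64_t mask = 0ull;
  float value = 1.0f;
};
using InputPlanes = std::vector<InputPlane>;

class NetworkComputation {
 public:
  virtual void AddInput(InputPlanes&& input) = 0;
  virtual void ComputeBlocking() = 0;
  virtual int GetBatchSize() const = 0;
  virtual float GetQVal(int sample) const = 0;
  virtual float GetDVal(int sample) const = 0;
  virtual float GetPVal(int sample, int move_id) const = 0;
  virtual float GetMVal(int sample) const = 0;
  virtual ~NetworkComputation() {}
};

// Hypothetical toy backend, for illustration only.
class ToyComputation : public NetworkComputation {
 public:
  void AddInput(InputPlanes&& input) override {
    inputs_.push_back(std::move(input));
  }
  void ComputeBlocking() override {
    q_.clear();
    for (const auto& planes : inputs_) {
      q_.push_back(planes.empty() ? 0.0f : planes[0].value);
    }
  }
  int GetBatchSize() const override {
    return static_cast<int>(inputs_.size());
  }
  float GetQVal(int sample) const override {
    assert(sample >= 0 && static_cast<std::size_t>(sample) < q_.size());
    return q_[sample];
  }
  float GetDVal(int) const override { return 0.0f; }
  float GetPVal(int, int) const override { return 1.0f / 1858.0f; }
  float GetMVal(int) const override { return 0.0f; }

 private:
  std::vector<InputPlanes> inputs_;
  std::vector<float> q_;
};

// Fill the batch, evaluate it, then read the results back.
std::vector<float> EvaluateBatch(NetworkComputation& computation,
                                 std::vector<InputPlanes> positions) {
  for (auto& position : positions) computation.AddInput(std::move(position));
  computation.ComputeBlocking();  // Blocks until all outputs are ready.
  std::vector<float> q;
  for (int i = 0; i < computation.GetBatchSize(); ++i) {
    q.push_back(computation.GetQVal(i));
  }
  return q;
}
```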
Input encoding
The 112 input planes are passed as an InputPlanes value, which is defined as follows:
// All input planes are 64 value vectors, every element of which is either
// 0 or some value, unique for the plane. Therefore, input is defined as
// a bitmask showing where to set the value, and the value itself.
struct InputPlane {
InputPlane() = default;
void SetAll() { mask = ~0ull; }
void Fill(float val) {
SetAll();
value = val;
}
std::uint64_t mask = 0ull;
float value = 1.0f;
};
using InputPlanes = std::vector<InputPlane>;
So, it’s a vector of 112 elements. Every element encodes an 8×8 plane using the InputPlane struct.
InputPlane contains a 64-bit field mask, which holds the mask of the non-zero elements of the 8×8 input matrix, and a float field value, which holds the value of those non-zero elements.
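Many NN libraries expect dense float input, so a backend usually has to expand this compact encoding. A minimal sketch follows. ExpandPlane and ExpandPlanes are hypothetical helper names, and the code only assumes that bit i of mask corresponds to element i of the plane. Which board square that element denotes is determined by the Lc0 encoder and is not assumed here.

```cpp
#include <array>
#include <cassert>
#include <cstdint>
#include <vector>

struct InputPlane {
  InputPlane() = default;
  void SetAll() { mask = ~0ull; }
  void Fill(float val) {
    SetAll();
    value = val;
  }
  std::uint64_t mask = 0ull;
  float value = 1.0f;
};
using InputPlanes = std::vector<InputPlane>;

// Expands one plane into 64 floats: `value` where the mask bit is set,
// 0 elsewhere.
std::array<float, 64> ExpandPlane(const InputPlane& plane) {
  std::array<float, 64> out{};  // Value-initialized to all zeros.
  for (int i = 0; i < 64; ++i) {
    if ((plane.mask >> i) & 1ull) out[i] = plane.value;
  }
  return out;
}

// Expands a whole position into one flat buffer, plane after plane.
std::vector<float> ExpandPlanes(const InputPlanes& planes) {
  assert(planes.size() == 112);  // One position is 112 planes.
  std::vector<float> out;
  out.reserve(planes.size() * 64);
  for (const auto& plane : planes) {
    const auto dense = ExpandPlane(plane);
    out.insert(out.end(), dense.begin(), dense.end());
  }
  return out;
}
```

For a batch, the same expansion would be repeated per sample, giving 112 × 64 floats per position.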
Returning output
The Get*() functions all take int sample as the first parameter, denoting which sample of the batch the output is requested for (zero-based).
Policy head
The policy head is accessed using the GetPVal(int sample, int move_id) function. move_id is an index in the move vector, between 0 and 1857 (inclusive).
Value head
- For VALUE_CLASSICAL: GetQVal() should return the value of the value head. GetDVal() should return 0.0.
- For VALUE_WDL: the values of the vector returned from the value head are called W, D and L. GetQVal() should return W-L. GetDVal() should return D.
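The two cases above reduce to a small mapping from the raw value head output to the (Q, D) pair that the getters return. The sketch below shows only that arithmetic. Wdl, ValueOutput, FromClassical and FromWdl are hypothetical names, and how the head output is produced or normalized is left to the backend.

```cpp
#include <cassert>

// Hypothetical containers for the value head output and the getter results.
struct Wdl {
  float w;
  float d;
  float l;
};
struct ValueOutput {
  float q;  // What GetQVal() would return.
  float d;  // What GetDVal() would return.
};

// VALUE_CLASSICAL: the head output is Q itself; D is always 0.
ValueOutput FromClassical(float value) { return {value, 0.0f}; }

// VALUE_WDL: Q is W - L, D is passed through.
ValueOutput FromWdl(const Wdl& wdl) { return {wdl.w - wdl.l, wdl.d}; }
```

For example, W = 0.5, D = 0.25, L = 0.25 gives Q = 0.25 and D = 0.25.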
Moves left head
The output of that head is accessed using the GetMVal() function.
Example backend implementations
“Real” backends:
Backend adapters to combine other backends (e.g. in the case of multiple GPUs):
Other backends:
- “Check” backend – checks whether two backends return the same result.
- Random backend