ipa: rpi: controller: awb: Add Neural Network AWB

Add an AWB algorithm which uses neural networks.

Signed-off-by: Peter Bailey <peter.bailey@raspberrypi.com>
Reviewed-by: David Plowman <david.plowman@raspberrypi.com>
Reviewed-by: Naushir Patuck <naush@raspberrypi.com>
Signed-off-by: Kieran Bingham <kieran.bingham@ideasonboard.com>
This commit is contained in:
Peter Bailey
2026-01-27 17:13:17 +00:00
committed by Kieran Bingham
parent 6d38984436
commit 045bfb1b8f
3 changed files with 451 additions and 0 deletions

View File

@@ -76,6 +76,11 @@ option('qcam',
value : 'auto',
description : 'Compile the qcam test application')
option('rpi-awb-nn',
type : 'feature',
value : 'auto',
description : 'Enable the Raspberry Pi Neural Network AWB algorithm')
option('test',
type : 'boolean',
value : false,

View File

@@ -32,6 +32,15 @@ rpi_ipa_controller_deps = [
libcamera_private,
]
tflite_dep = dependency('tensorflow-lite', required : get_option('rpi-awb-nn'))
if tflite_dep.found()
rpi_ipa_controller_sources += files([
'rpi/awb_nn.cpp',
])
rpi_ipa_controller_deps += tflite_dep
endif
rpi_ipa_controller_lib = static_library('rpi_ipa_controller', rpi_ipa_controller_sources,
include_directories : libipa_includes,
dependencies : rpi_ipa_controller_deps)

View File

@@ -0,0 +1,437 @@
/* SPDX-License-Identifier: BSD-2-Clause */
/*
* Copyright (C) 2025, Raspberry Pi Ltd
*
* AWB control algorithm using neural network
*
* The AWB Neural Network algorithm can be run entirely with the code here
* and the suppllied TFLite models. Those interested in the full model
* definitions, or who may want to re-train the models should visit
*
* https://github.com/raspberrypi/awb_nn
*
* where you will find full source code for the models, the full datasets
* used for training our supplied models, and full instructions for capturing
* your own images and re-training the models for your own use cases.
*/
#include <algorithm>
#include <libcamera/base/file.h>
#include <libcamera/base/log.h>
#include <tensorflow/lite/interpreter.h>
#include <tensorflow/lite/kernels/register.h>
#include <tensorflow/lite/model.h>
#include "../awb_algorithm.h"
#include "../awb_status.h"
#include "../lux_status.h"
#include "libipa/pwl.h"
#include "awb.h"
using namespace libcamera;
LOG_DECLARE_CATEGORY(RPiAwb)
constexpr double kDefaultCT = 4500.0;
/*
* The neural networks are trained to work on images rendered at a canonical
* colour temperature. That value is 5000K, which must be reproduced here.
*/
constexpr double kNetworkCanonicalCT = 5000.0;
#define NAME "rpi.nn.awb"
namespace RPiController {
struct AwbNNConfig {
AwbNNConfig() = default;
int read(const libcamera::YamlObject &params, AwbConfig &config);
/* An empty model will check default locations for model.tflite */
std::string model;
float minTemp;
float maxTemp;
bool enableNn;
/* CCM matrix for canonical network CT */
double ccm[9];
};
class AwbNN : public Awb
{
public:
AwbNN(Controller *controller = nullptr);
~AwbNN();
char const *name() const override;
void initialise() override;
int read(const libcamera::YamlObject &params) override;
protected:
void doAwb() override;
void prepareStats() override;
private:
bool isAutoEnabled() const;
AwbNNConfig nnConfig_;
void transverseSearch(double t, double &r, double &b);
RGB processZone(RGB zone, float red_gain, float blue_gain);
void awbNN();
void loadModel();
libcamera::Size zoneSize_;
std::unique_ptr<tflite::FlatBufferModel> model_;
std::unique_ptr<tflite::Interpreter> interpreter_;
};
int AwbNNConfig::read(const libcamera::YamlObject &params, AwbConfig &config)
{
model = params["model"].get<std::string>("");
minTemp = params["min_temp"].get<float>(2800.0);
maxTemp = params["max_temp"].get<float>(7600.0);
for (int i = 0; i < 9; i++)
ccm[i] = params["ccm"][i].get<double>(0.0);
enableNn = params["enable_nn"].get<int>(1);
if (enableNn) {
if (!config.hasCtCurve()) {
LOG(RPiAwb, Error) << "CT curve not specified";
enableNn = false;
}
if (!model.empty() && model.find(".tflite") == std::string::npos) {
LOG(RPiAwb, Error) << "Model must be a .tflite file";
enableNn = false;
}
bool validCcm = true;
for (int i = 0; i < 9; i++)
if (ccm[i] == 0.0)
validCcm = false;
if (!validCcm) {
LOG(RPiAwb, Error) << "CCM not specified or invalid";
enableNn = false;
}
if (!enableNn) {
LOG(RPiAwb, Warning) << "Neural Network AWB misconfigured - switch to Grey method";
}
}
if (!enableNn) {
config.sensitivityR = config.sensitivityB = 1.0;
config.greyWorld = true;
}
return 0;
}
AwbNN::AwbNN(Controller *controller)
: Awb(controller)
{
zoneSize_ = getHardwareConfig().awbRegions;
}
AwbNN::~AwbNN()
{
}
char const *AwbNN::name() const
{
return NAME;
}
int AwbNN::read(const libcamera::YamlObject &params)
{
int ret;
ret = config_.read(params);
if (ret)
return ret;
ret = nnConfig_.read(params, config_);
if (ret)
return ret;
return 0;
}
static bool checkTensorShape(TfLiteTensor *tensor, const int *expectedDims, const int expectedDimsSize)
{
return std::equal(expectedDims, expectedDims + expectedDimsSize,
tensor->dims->data, tensor->dims->data + tensor->dims->size);
}
static std::string buildDimString(const int *dims, const int dimsSize)
{
return "[" + utils::join(Span(dims, dimsSize), ",") + "]";
}
void AwbNN::loadModel()
{
std::string modelPath;
if (getTarget() == "bcm2835") {
modelPath = "/ipa/rpi/vc4/awb_model.tflite";
} else {
modelPath = "/ipa/rpi/pisp/awb_model.tflite";
}
if (nnConfig_.model.empty()) {
std::string root = utils::libcameraSourcePath();
if (!root.empty()) {
modelPath = root + modelPath;
} else {
modelPath = LIBCAMERA_DATA_DIR + modelPath;
}
if (!File::exists(modelPath)) {
LOG(RPiAwb, Error) << "No model file found in standard locations";
nnConfig_.enableNn = false;
return;
}
} else {
modelPath = nnConfig_.model;
}
LOG(RPiAwb, Debug) << "Attempting to load model from: " << modelPath;
model_ = tflite::FlatBufferModel::BuildFromFile(modelPath.c_str());
if (!model_) {
LOG(RPiAwb, Error) << "Failed to load model from " << modelPath;
nnConfig_.enableNn = false;
return;
}
tflite::MutableOpResolver resolver;
tflite::ops::builtin::BuiltinOpResolver builtin_resolver;
resolver.AddAll(builtin_resolver);
tflite::InterpreterBuilder(*model_, resolver)(&interpreter_);
if (!interpreter_) {
LOG(RPiAwb, Error) << "Failed to build interpreter for model " << nnConfig_.model;
nnConfig_.enableNn = false;
return;
}
interpreter_->AllocateTensors();
TfLiteTensor *inputTensor = interpreter_->input_tensor(0);
TfLiteTensor *inputLuxTensor = interpreter_->input_tensor(1);
TfLiteTensor *outputTensor = interpreter_->output_tensor(0);
if (!inputTensor || !inputLuxTensor || !outputTensor) {
LOG(RPiAwb, Error) << "Model missing input or output tensor";
nnConfig_.enableNn = false;
return;
}
const int expectedInputDims[] = { 1, (int)zoneSize_.height, (int)zoneSize_.width, 3 };
const int expectedInputLuxDims[] = { 1 };
const int expectedOutputDims[] = { 1 };
if (!checkTensorShape(inputTensor, expectedInputDims, 4)) {
LOG(RPiAwb, Error) << "Model input tensor dimension mismatch. Expected: " << buildDimString(expectedInputDims, 4)
<< ", Got: " << buildDimString(inputTensor->dims->data, inputTensor->dims->size);
nnConfig_.enableNn = false;
return;
}
if (!checkTensorShape(inputLuxTensor, expectedInputLuxDims, 1)) {
LOG(RPiAwb, Error) << "Model input lux tensor dimension mismatch. Expected: " << buildDimString(expectedInputLuxDims, 1)
<< ", Got: " << buildDimString(inputLuxTensor->dims->data, inputLuxTensor->dims->size);
nnConfig_.enableNn = false;
return;
}
if (!checkTensorShape(outputTensor, expectedOutputDims, 1)) {
LOG(RPiAwb, Error) << "Model output tensor dimension mismatch. Expected: " << buildDimString(expectedOutputDims, 1)
<< ", Got: " << buildDimString(outputTensor->dims->data, outputTensor->dims->size);
nnConfig_.enableNn = false;
return;
}
if (inputTensor->type != kTfLiteFloat32 || inputLuxTensor->type != kTfLiteFloat32 || outputTensor->type != kTfLiteFloat32) {
LOG(RPiAwb, Error) << "Model input and output tensors must be float32";
nnConfig_.enableNn = false;
return;
}
LOG(RPiAwb, Info) << "Model loaded successfully from " << modelPath;
LOG(RPiAwb, Debug) << "Model validation successful - Input Image: "
<< buildDimString(expectedInputDims, 4)
<< ", Input Lux: " << buildDimString(expectedInputLuxDims, 1)
<< ", Output: " << buildDimString(expectedOutputDims, 1) << " floats";
}
void AwbNN::initialise()
{
Awb::initialise();
if (nnConfig_.enableNn) {
loadModel();
if (!nnConfig_.enableNn) {
LOG(RPiAwb, Warning) << "Neural Network AWB failed to load - switch to Grey method";
config_.greyWorld = true;
config_.sensitivityR = config_.sensitivityB = 1.0;
}
}
}
void AwbNN::prepareStats()
{
zones_.clear();
/*
* LSC has already been applied to the stats in this pipeline, so stop
* any LSC compensation. We also ignore config_.fast in this version.
*/
generateStats(zones_, statistics_, 0.0, 0.0, getGlobalMetadata(), 0.0, 0.0, 0.0);
/*
* apply sensitivities, so values appear to come from our "canonical"
* sensor.
*/
for (auto &zone : zones_) {
zone.R *= config_.sensitivityR;
zone.B *= config_.sensitivityB;
}
}
void AwbNN::transverseSearch(double t, double &r, double &b)
{
int spanR = -1, spanB = -1;
config_.ctR.eval(t, &spanR);
config_.ctB.eval(t, &spanB);
const int diff = 10;
double rDiff = config_.ctR.eval(t + diff, &spanR) -
config_.ctR.eval(t - diff, &spanR);
double bDiff = config_.ctB.eval(t + diff, &spanB) -
config_.ctB.eval(t - diff, &spanB);
ipa::Pwl::Point transverse({ bDiff, -rDiff });
if (transverse.length2() < 1e-6)
return;
transverse = transverse / transverse.length();
double transverseRange = config_.transverseNeg + config_.transversePos;
const int maxNumDeltas = 12;
int numDeltas = floor(transverseRange * 100 + 0.5) + 1;
numDeltas = std::clamp(numDeltas, 3, maxNumDeltas);
ipa::Pwl::Point points[maxNumDeltas];
int bestPoint = 0;
for (int i = 0; i < numDeltas; i++) {
points[i][0] = -config_.transverseNeg +
(transverseRange * i) / (numDeltas - 1);
ipa::Pwl::Point rbTest = ipa::Pwl::Point({ r, b }) +
transverse * points[i].x();
double rTest = rbTest.x(), bTest = rbTest.y();
double gainR = 1 / rTest, gainB = 1 / bTest;
double delta2Sum = computeDelta2Sum(gainR, gainB, 0.0, 0.0);
points[i][1] = delta2Sum;
if (points[i].y() < points[bestPoint].y())
bestPoint = i;
}
bestPoint = std::clamp(bestPoint, 1, numDeltas - 2);
ipa::Pwl::Point rbBest = ipa::Pwl::Point({ r, b }) +
transverse * interpolateQuadatric(points[bestPoint - 1],
points[bestPoint],
points[bestPoint + 1]);
r = rbBest.x();
b = rbBest.y();
}
AwbNN::RGB AwbNN::processZone(AwbNN::RGB zone, float redGain, float blueGain)
{
/*
* Renders the pixel at canonical network colour temperature
*/
RGB zoneGains = zone;
zoneGains.R *= redGain;
zoneGains.G *= 1.0;
zoneGains.B *= blueGain;
RGB zoneCcm;
zoneCcm.R = nnConfig_.ccm[0] * zoneGains.R + nnConfig_.ccm[1] * zoneGains.G + nnConfig_.ccm[2] * zoneGains.B;
zoneCcm.G = nnConfig_.ccm[3] * zoneGains.R + nnConfig_.ccm[4] * zoneGains.G + nnConfig_.ccm[5] * zoneGains.B;
zoneCcm.B = nnConfig_.ccm[6] * zoneGains.R + nnConfig_.ccm[7] * zoneGains.G + nnConfig_.ccm[8] * zoneGains.B;
return zoneCcm;
}
void AwbNN::awbNN()
{
float *inputData = interpreter_->typed_input_tensor<float>(0);
float *inputLux = interpreter_->typed_input_tensor<float>(1);
float redGain = 1.0 / config_.ctR.eval(kNetworkCanonicalCT);
float blueGain = 1.0 / config_.ctB.eval(kNetworkCanonicalCT);
for (unsigned int i = 0; i < zoneSize_.height; i++) {
for (unsigned int j = 0; j < zoneSize_.width; j++) {
unsigned int zoneIdx = i * zoneSize_.width + j;
RGB processedZone = processZone(zones_[zoneIdx] * (1.0 / 65535), redGain, blueGain);
unsigned int baseIdx = zoneIdx * 3;
inputData[baseIdx + 0] = static_cast<float>(processedZone.R);
inputData[baseIdx + 1] = static_cast<float>(processedZone.G);
inputData[baseIdx + 2] = static_cast<float>(processedZone.B);
}
}
inputLux[0] = static_cast<float>(lux_);
TfLiteStatus status = interpreter_->Invoke();
if (status != kTfLiteOk) {
LOG(RPiAwb, Error) << "Model inference failed with status: " << status;
return;
}
float *outputData = interpreter_->typed_output_tensor<float>(0);
double t = outputData[0];
LOG(RPiAwb, Debug) << "Model output temperature: " << t;
t = std::clamp(t, mode_->ctLo, mode_->ctHi);
double r = config_.ctR.eval(t);
double b = config_.ctB.eval(t);
transverseSearch(t, r, b);
LOG(RPiAwb, Debug) << "After transverse search: Temperature: " << t << " Red gain: " << 1.0 / r << " Blue gain: " << 1.0 / b;
asyncResults_.temperatureK = t;
asyncResults_.gainR = 1.0 / r * config_.sensitivityR;
asyncResults_.gainG = 1.0;
asyncResults_.gainB = 1.0 / b * config_.sensitivityB;
}
void AwbNN::doAwb()
{
prepareStats();
if (zones_.size() == (zoneSize_.width * zoneSize_.height) && nnConfig_.enableNn)
awbNN();
else
awbGrey();
statistics_.reset();
}
/* Register algorithm with the system. */
static Algorithm *create(Controller *controller)
{
return new AwbNN(controller);
}
static RegisterAlgorithm reg(NAME, &create);
} /* namespace RPiController */