#!/usr/bin/env python
# ******************************************************************************
# Copyright 2025 Brainchip Holdings Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ******************************************************************************
__all__ = ["HybridModel"]
import akida
import onnx
import os
import math
import onnxruntime
import numpy as np
from collections import defaultdict
import onnx_ir as ir
from quantizeml.onnx_support.layers import BRN_OPSET, ONNX_OPSET
def _return_element_or_list(list_object):
if len(list_object) == 1:
return list_object[0]
return list_object
def _get_ir_dimensions(values):
# Remove batch dimension.
return _return_element_or_list([list(v.shape[1:]) for v in values])
def _find_nodes_between_values(input_value, output_value):
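    """Collects the nodes lying on the paths from ``input_value`` to ``output_value``.

    Nodes are discovered by walking producers backwards from ``output_value``, so for a
    linear chain A -> B -> C (with C producing ``output_value``) the returned list is
    [C, B, A], i.e. in reverse graph order.
    """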
nodes = []
queue = [output_value]
while len(queue) > 0:
out_value = queue.pop(0)
node = out_value.producer()
if node in nodes or node is None:
continue
nodes.append(node)
if any(in_value == input_value for in_value in node.inputs):
continue
queue.extend(node.inputs)
return nodes
def _print_table(table, title, new_splits=None):
"""Helper method to print formatted tables.
Args:
table (list): 2D list containing table data (including headers).
title (str): title of the table.
new_splits (list, optional): list indicating where to insert separators. Defaults to
None.
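
    A small usage sketch:

    >>> _print_table([["col1", "col2"], ["a", "b"]], "Demo")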
"""
# Convert to np.array
to_np = np.vectorize(str, otypes=['O'])
table = to_np(table)
# Get column lengths
str_len_f = np.vectorize(lambda cell: len(str(cell)))
str_lens = np.amax(str_len_f(table), 0)
line_len = np.sum(str_lens)
# Prepare format rows
size_formats = np.vectorize(lambda cell: f"{{:{cell}.{cell}}}")
format_strings = size_formats(str_lens)
format_row = " ".join(format_strings)
# Generate separators
separator_len = line_len + 2 * len(table[0])
separator = "_" * separator_len
double_separator = "=" * separator_len
# Print header
center_format = f"{{:^{separator_len}}}"
if title is not None:
print(center_format.format(title))
print(separator)
print(format_row.format(*table[0]))
rows = table[1:, :]
if new_splits is None:
new_splits = [False] * len(rows)
assert len(rows) == len(new_splits)
if not any(new_splits):
print(double_separator)
# Print body
for row, new_split in zip(rows, new_splits):
if isinstance(new_split, str):
print()
            # Compute the number of characters on each side of the text
space_len = max((separator_len - len(new_split)) / 2., 1.)
space_left = "=" * int(np.ceil(space_len - 1))
space_right = "=" * int(np.floor(space_len - 1))
print(space_left, new_split, space_right)
print()
elif new_split is False and row[0] != rows[0][0]:
print(separator)
print(format_row.format(*row))
print(separator)
def get_ir_input_dtype(ak_layer):
"""Determines the appropriate ONNX IR input data type for a given Akida layer.
Args:
        ak_layer (akida.Layer): the Akida layer for which to determine the input data type.
Returns:
ir.TensorType: the corresponding ONNX IR tensor type.
"""
layer_params = ak_layer.parameters
    # Only unsigned inputs map to UINT8: InputConv2D always consumes UINT8 data, as does
    # an InputData layer configured with unsigned inputs. Every other layer consumes INT8.
    if (layer_params.layer_type == akida.LayerType.InputConv2D
            or (layer_params.layer_type == akida.LayerType.InputData
                and not layer_params.input_signed)):
        return ir.TensorType(ir.DataType.UINT8)
    return ir.TensorType(ir.DataType.INT8)
def get_ir_output_dtype(ak_layer):
"""Determines the appropriate ONNX IR output data type for a given Akida layer.
Args:
        ak_layer (akida.Layer): the Akida layer for which to determine the output data type.
Returns:
ir.TensorType: the corresponding ONNX IR tensor type.
"""
layer_params = ak_layer.parameters
if layer_params.layer_type == akida.LayerType.InputData:
return get_ir_input_dtype(ak_layer)
if layer_params.layer_type == akida.LayerType.Dequantizer:
return ir.TensorType(ir.DataType.FLOAT)
return ir.TensorType(ir.DataType.INT8 if layer_params.output_bits <= 8 else ir.DataType.INT32)
def convert_ak_model_into_onnx(ak_model,
program_path,
ir_input,
in_channel_last=True,
out_channel_last=True,
flat_output=False):
"""Converts an Akida model into its equivalent ONNX operators.
    This function wraps the Akida model as a sequence of AkidaOp ONNX nodes,
preserving input and output shapes and data types.
Args:
ak_model (akida.Model): the Akida model to convert.
program_path (str): path where Akida program(s) will be saved to be referenced
by the ONNX node.
        ir_input (onnx_ir.Value): value used as the ONNX model input.
        in_channel_last (bool, optional): whether the input shape is in channel-last format.
            Since Akida requires channel-last inputs, a Transpose is inserted at the input
            when this is False. Defaults to True.
        out_channel_last (bool, optional): whether the output shape is in channel-last format.
            A Transpose is appended at the output when this is False. Defaults to True.
        flat_output (bool, optional): whether to flatten the output to a 2D shape.
            Defaults to False.
Returns:
ir.Tape: the registered operations tape.
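
    A hedged usage sketch (the tensor name, shapes and ``mapped_model`` are illustrative
    assumptions, not fixed by this function):

    >>> ir_input = ir.Value(name="x", shape=ir.Shape(("N", 3, 32, 32)),
    ...                     type=ir.TensorType(ir.DataType.UINT8))
    >>> tape = convert_ak_model_into_onnx(mapped_model, "program.bin", ir_input,
    ...                                   in_channel_last=False)
    >>> [node.op_type for node in tape.nodes]  # e.g. ['Transpose', 'AkidaOpInt8']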
"""
program_base, ext = os.path.splitext(program_path)
assert ext == ".bin", f"Wrong extension in {program_path}. It must be '.bin'."
if not isinstance(ak_model.device, akida.HardwareDevice):
raise ValueError("Model must be mapped on a physical HardwareDevice "
f"(virtual devices cannot be used). Current device: {ak_model.device}.")
# Save program(s) from the Akida model.
program_path_and_layer = []
for idx, sequence in enumerate(ak_model.sequences):
if sequence.backend != akida.BackendType.Hardware:
raise RuntimeError(f"Impossible to extract the program(s): sequence {idx} of model "
"has not been mapped on a hardware-based backend. "
f"Current backend: {sequence.backend}.")
if len(ak_model.sequences) > 1:
program_path = program_base + f"_{idx}" + ext
program_path_and_layer.insert(idx, (program_path, sequence.passes[-1].layers[-1]))
with open(program_path, "wb") as f:
f.write(sequence.program)
# Define input/output.
if get_ir_output_dtype(ak_model.layers[-1]).dtype == ir.DataType.FLOAT:
    # We reject this because a Dequantizer must run on the CPU.
raise RuntimeError("Cannot convert models whose output is float.")
# Create tape to record ops.
tape = ir.tape.Tape()
# Convert input from channel-first to channel-last.
y = ir_input
if not in_channel_last:
if len(ir_input.shape) == 2:
# When input is 2D, we need to expand the input to 4D.
y = ir.Value(name=f"{y.name}/transp",
shape=ir.Shape((y.shape[0], 1, 1, y.shape[1])),
type=ir.TensorType(y.dtype))
reshape = tape.initializer(ir.tensor((0, 1, 1, y.shape[-1]),
dtype=ir.DataType.INT64,
name=f"{y.name}/reshape_shape"))
tape.op(op_type="Reshape",
inputs=[ir_input, reshape],
version=ONNX_OPSET.version,
output=y,
metadata_props={"tag": "added on conversion"})
else:
y = ir.Value(name=f"{y.name}/transp",
shape=ir.Shape((y.shape[0], *y.shape[2:], y.shape[1])),
type=ir.TensorType(y.dtype))
tape.op(op_type="Transpose",
inputs=[ir_input],
attributes={"perm": [0, 2, 3, 1]},
version=ONNX_OPSET.version,
output=y,
metadata_props={"tag": "added on conversion"})
# Record the Akida op(s).
for program_path, ak_layer in program_path_and_layer:
x = y
op_type = "AkidaOpInt8" if ak_layer.parameters.output_bits <= 8 else "AkidaOpInt32"
y = ir.Value(name=f"{ak_layer.name}/ak_op",
shape=ir.Shape((x.shape[0], *ak_layer.output_dims)),
type=get_ir_output_dtype(ak_layer))
tape.op(op_type=op_type,
inputs=[x],
attributes={"program_path": ir.AttrString("program_path", str(program_path))},
name=ak_layer.name,
domain=BRN_OPSET.domain,
version=BRN_OPSET.version,
output=y)
# Convert output from channel-last to channel-first.
if not out_channel_last:
x = y
y = ir.Value(name=f"{x.name}/transp",
shape=ir.Shape((x.shape[0], x.shape[-1], *x.shape[1:-1])),
type=ir.TensorType(x.dtype))
tape.op(op_type="Transpose",
inputs=[x],
attributes={"perm": [0, 3, 1, 2]},
version=ONNX_OPSET.version,
output=y,
metadata_props={"tag": "added on conversion"})
    # Flatten the output if the expected one is 2D.
if flat_output:
x = y
y = ir.Value(name=f"{x.name}/flat",
shape=ir.Shape((x.shape[0], math.prod(x.shape[1:]))),
type=ir.TensorType(x.dtype))
tape.op(op_type="Flatten",
inputs=[x],
version=ONNX_OPSET.version,
output=y,
metadata_props={"tag": "added on conversion"})
return tape
class HybridModel:
"""Tensor-driven container for mixing ONNX and Akida execution.
``HybridModel`` wraps a single ONNX graph and allows sections of the graph to be replaced
by Akida models. Integration is performed by identifying an input tensor (the entry point of
the Akida model) and an output tensor (where execution returns to the ONNX graph).
Akida segments are tracked internally so they can be mapped to hardware devices.
Args:
model (onnx.ModelProto or ir.Model): the base graph to augment.
name (str, optional): name of the hybrid model. Defaults to "HybridModel".
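
    A minimal usage sketch (the model file and its shapes are illustrative assumptions):

    >>> hybrid = HybridModel(onnx.load("model.onnx"), name="my_hybrid")
    >>> hybrid.input_shape  # e.g. [3, 224, 224] for a channel-first image model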
"""
def __init__(self, model, name="HybridModel"):
self.name = name
self._ak_models = {}
if isinstance(model, onnx.ModelProto):
model = ir.from_proto(model)
ir.passes.common.CheckerPass(full_check=True)(model)
self.model = model
self._tensors = ir.convenience.create_value_mapping(self.model.graph)
@property
def input_shape(self):
"""Returns the hybrid model input shape.
Returns:
list: the model input shape.
"""
return _get_ir_dimensions(self.model.graph.inputs)
@property
def output_shape(self):
"""Returns the hybrid model output shape.
Returns:
list: the model output shape.
"""
return _get_ir_dimensions(self.model.graph.outputs)
@property
def akida_models(self):
"""Returns a list of Akida models within the hybrid model.
Returns:
tuple: the akida models within the hybrid model.
"""
return tuple(self._ak_models.values())
def _add(self, model, incoming_value_name, outgoing_value_name):
"""Replace all nodes between incoming and outgoing values by an akida Model.
Args:
model (akida.Model): the model to add.
incoming_value_name (str): a tensor name used to link the model input.
outgoing_value_name (str): a tensor name used to link the model output.
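
        A hedged sketch (the tensor names are illustrative assumptions):

        >>> hybrid._add(ak_model, "conv1_output", "relu3_output")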
"""
if not isinstance(model, akida.Model):
raise TypeError(f"Failed to add {model} to '{self.name}'. Expected an {akida.Model}.")
# Search incoming and outgoing values.
if (incoming_value := self._tensors.get(incoming_value_name)) is None:
raise ValueError(f"'{incoming_value_name}' not found in the model.")
if (outgoing_value := self._tensors.get(outgoing_value_name)) is None:
raise ValueError(f"'{outgoing_value_name}' not found in the model.")
        # Check that the model can be added.
self._check_model_integrity(model, incoming_value, outgoing_value)
# Store the model.
self._ak_models[(incoming_value_name, outgoing_value_name)] = model
def map(self, device, mode=akida.MapMode.AllNps):
"""Map (if possible) all akida models to a given device.
Args:
device (akida.Device): An Akida device.
mode (akida.MapMode, optional): The mapping mode. Defaults to AllNps.
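
        A hedged sketch, assuming a physical Akida device is attached:

        >>> device = akida.devices()[0]
        >>> hybrid.map(device)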
"""
for idx, model in enumerate(self.akida_models):
try:
model.map(device, hw_only=True, mode=mode)
except Exception as e:
raise RuntimeError(
f"Failed to map Akida model at index {idx} within 'akida_models'. "
f"Reason: {str(e)}."
)
def generate_inference_model(self, dirpath="."):
"""Generates a unified ONNX inference model from all sub-models in the HybridModel.
        This method injects all Akida sub-models added to the HybridModel into the ONNX model,
        producing a graph suitable for inference. It handles the conversion of Akida models
        to ONNX and connects sub-models according to their inbounds.
>>> inference_model = model.generate_inference_model()
>>> sess = AkidaInferenceSession(inference_model.SerializeToString())
>>> outputs = sess.run(None, feeds)
Args:
dirpath (str, optional): directory path where Akida programs will be saved.
Defaults to the current directory.
Returns:
onnx.ModelProto: the combined ONNX model ready for inference.
"""
def _is_akida_op(node):
return "AkidaOp" in node.op_type
def _added_by_conversion(node):
return node.metadata_props.get("tag", "") == "added on conversion"
if len(self._ak_models) == 0:
raise RuntimeError("At least one model is required to generate the inference model.")
if (len(self.akida_models) > 0 and
not all(self.akida_models[0].device == m.device for m in self.akida_models)):
raise RuntimeError("All akida models must be mapped on the same device.")
# Clone the original model to avoid modifying it.
ir_model = ir.from_proto(ir.to_proto(self.model))
ak_outgoing_names = {out_value for _, out_value in self._ak_models.keys()}
_all_outgoing_names = [v.name for n in ir_model.graph for v in n.outputs]
# Replace each akida segment by its ONNX equivalent.
# Note we sort models in reverse order to properly compute 'out_channel_last'.
_ak_models = sorted(self._ak_models.items(),
key=lambda x: _all_outgoing_names.index(x[0][1]),
reverse=True)
for sub_model_id, ((in_name, out_name), ak_model) in enumerate(_ak_models):
_tensors = ir.convenience.create_value_mapping(ir_model.graph)
in_value = _tensors[in_name]
out_value = _tensors[out_name]
# Extract nodes between in_value and out_value.
old_nodes = _find_nodes_between_values(in_value, out_value)
            # The model requires:
            # * an input transpose when the input comes from ONNX (Akida expects channel-last),
            # * an output transpose if there is at least one ONNX consumer,
            # * flattening the output if it is 2D (Akida outputs 4D tensors).
in_channel_last = in_name in ak_outgoing_names
out_channel_last = all(_is_akida_op(node) for node in out_value.consumers())
flat_output = len(out_value.shape) == 2
# Change in_channel_last if another branch has already transposed the input.
if not in_channel_last:
for consumer in in_value.consumers():
if _added_by_conversion(consumer):
in_channel_last = True
# Link block nodes to the transpose output.
in_value = consumer.outputs[0]
in_name = in_value.name
old_nodes[0].replace_input_with(0, in_value)
break
# Convert the model to ONNX.
program_path = os.path.join(dirpath, f"program_{sub_model_id}.bin")
sub_model = convert_ak_model_into_onnx(ak_model,
program_path,
ir_input=in_value,
in_channel_last=in_channel_last,
out_channel_last=out_channel_last,
flat_output=flat_output)
for initializer in sub_model.initializers:
ir_model.graph.register_initializer(initializer)
# Replace nodes in the original graph.
            # Note we overwrite out_value.shape since the ir helper replaces it with the old one.
new_out_value = sub_model.nodes[-1].outputs[0]
out_value.shape = new_out_value.shape
ir.convenience.replace_nodes_and_values(
ir_model.graph,
in_value.producer(),
old_nodes,
sub_model.nodes,
[out_value],
[new_out_value])
            # Fix the link for children that do not require the transpose.
if not out_channel_last:
if (transpose_node := sub_model.nodes[-1]).op_type == "Flatten":
transpose_node = sub_model.nodes[-2]
for consumer in new_out_value.consumers():
if _is_akida_op(consumer):
                        # Link the node directly to the transpose input.
consumer.replace_input_with(0, transpose_node.inputs[0])
# Convert ir_model to onnx.ModelProto
# Note we use RemoveUnusedNodesPass to clean unused initializers.
ir.passes.common.RemoveUnusedNodesPass()(ir_model)
model = ir.to_proto(ir_model)
# Sanity check.
onnx.checker.check_model(model, full_check=True)
return model
def compute_data_movement(self):
"""Computes the data movement between CPU and Akida for models in the HybridModel.
For each Akida model sequence, this method calculates:
- The amount of data transferred from CPU to Akida.
- The amount of data transferred from Akida to CPU.
The size is computed as the product of the input or output dimensions.
Returns:
            list of dict: a list of dictionaries, each containing:
                - "layer": the Akida layer involved in the data transfer.
                - "type": a string indicating the direction ("CPU -> Akida" or "Akida -> CPU").
                - "size": the size in bytes of the data movement.
"""
data_movement = []
# Compute data movement for each sequence.
for ak_model in self.akida_models:
for seq in ak_model.sequences:
# First layer requires a data movement from CPU to Akida.
input_layer = seq.passes[0].layers[0]
factor = get_ir_input_dtype(input_layer).dtype.itemsize
data_movement.append({"layer": input_layer,
"type": "CPU -> Akida",
"size": math.prod(input_layer.input_dims) * factor})
# Last layer requires a data movement from Akida to CPU.
output_layer = seq.passes[-1].layers[-1]
factor = get_ir_output_dtype(output_layer).dtype.itemsize
data_movement.append({"layer": output_layer,
"type": "Akida -> CPU",
"size": math.prod(output_layer.output_dims) * factor})
return data_movement
def summary(self):
"""Prints a string summary of the hybrid model.
This method prints a summary with details for every layer:
- Layer name and type
- Backend (ONNX or Akida)
- Output shape
- Inbounds (list of inbound layer names)
- Data movement (in bytes, only for layers involved in data transfer)
"""
# Prepare headers
headers = ['Layer (type)', 'Output shape', 'Inbounds', 'Data movement']
# Compute data movement and create a lookup dictionary
data_movement_info = self.compute_data_movement()
data_movement_map = {}
for dm in data_movement_info:
if dm["layer"].name not in data_movement_map:
data_movement_map[dm["layer"].name] = dm
else:
# If layer already exists, sum the sizes and change the type (for both directions)
data_movement_map[dm["layer"].name]["size"] += dm["size"]
data_movement_map[dm["layer"].name]["type"] = "CPU -> Akida -> CPU"
# Build inbound and outbound mapping
inbound_map = defaultdict(list)
outbound_map = defaultdict(list)
for node in self.model.graph:
for input_value in node.inputs:
# Skip initializers (weights, biases, etc.)
if not input_value.is_initializer():
inbound_map[node.name].append(input_value.name)
outbound_map[node.name].extend(out.name for out in node.outputs)
_tensors = ir.convenience.create_value_mapping(self.model.graph)
ak_trigger = {}
skip_nodes = []
for (in_value, out_value), ak_model in self._ak_models.items():
nodes = _find_nodes_between_values(_tensors[in_value], _tensors[out_value])
ak_trigger.update({nodes[0]: ak_model})
skip_nodes.extend(nodes[1:])
# Prepare table data
table = [headers]
new_splits = []
def _add_layer_to_table(layer, layer_idx, layer_type, output_shape, separator):
layer_name_str = f"{layer.name} ({layer_type})"
# Inbounds
inbounds = inbound_map.get(layer.name, [])
inbounds_str = ", ".join(inbounds) if inbounds else "N/A"
# Data movement
if layer.name in data_movement_map:
dm = data_movement_map[layer.name]
# Convert bytes to KB
size_kb = dm['size'] / 1024.0
data_movement_str = f"{size_kb:.2f} KB ({dm['type']})"
else:
data_movement_str = "N/A"
row = [layer_name_str, output_shape, inbounds_str, data_movement_str]
table.append(row)
# Add separator only before first layer of each model
new_splits.append(separator if layer_idx == 0 else False)
def _add_akida_model_to_table(ak_model, model_idx):
separator = f"--- Akida Sub-model {model_idx} ---"
layers = list(ak_model.layers)
            # Skip the leading InputData layer of the Akida model
if layers[0].parameters.layer_type == akida.LayerType.InputData:
layers = layers[1:]
for layer_idx, layer in enumerate(layers):
layer_type = layer.parameters.layer_type.name
output_shape = str(list(layer.output_dims))
_add_layer_to_table(layer, layer_idx, layer_type, output_shape, separator)
def _add_onnx_sequence_to_table(onnx_sequence, model_idx):
separator = f"--- ONNX Sub-model {model_idx} ---"
for node_idx, node in enumerate(onnx_sequence):
layer_type = node.op_type
output_shape = (
str(list(node.outputs[0].shape[1:])) if node.outputs else "N/A")
_add_layer_to_table(node, node_idx, layer_type, output_shape, separator)
onnx_sequence = []
model_idx = 0
for node in self.model.graph:
if node in skip_nodes:
continue
if node in ak_trigger:
if onnx_sequence:
_add_onnx_sequence_to_table(onnx_sequence, model_idx)
model_idx += 1
onnx_sequence = []
_add_akida_model_to_table(ak_trigger[node], model_idx)
model_idx += 1
else:
onnx_sequence.append(node)
if onnx_sequence:
_add_onnx_sequence_to_table(onnx_sequence, model_idx)
# Print the table
_print_table(table, f"HybridModel Summary: {self.name}", new_splits)
def _check_model_integrity(self, model, incoming_value, outgoing_value):
def _check_match(ak_type, ak_shape, value):
            assert ak_type == value.type, f"Type mismatch. Expected {value.type}, got {ak_type}."
# Check if value shape matches with akida expected format (4D channel last).
# Note that the nodes required to perform this transformation are added
# by generate_inference_model.
v_shape = _get_ir_dimensions([value])
if len(v_shape) < len(ak_shape):
# Fill missing dimensions to match with akida shape.
v_shape = v_shape + [1] * (len(ak_shape) - len(v_shape))
# Convert v_shape to channel-last format for comparison.
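            # e.g. a channel-first [3, 32, 32] becomes [32, 32, 3].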
v_shape = v_shape[1:] + v_shape[:1]
assert v_shape == ak_shape, f"Shape mismatch. Expected {v_shape}, got {ak_shape}."
        # For a model to be compatible:
        # * its input must match the incoming value info.
try:
_check_match(get_ir_input_dtype(model.layers[0]), model.input_shape, incoming_value)
except Exception as e:
raise ValueError(f"Impossible to connect {incoming_value.name} with "
f"{model.layers[0].name}.") from e
        # * its output must match the outgoing value info.
try:
_check_match(get_ir_output_dtype(model.layers[-1]), model.output_shape, outgoing_value)
except Exception as e:
raise ValueError(f"Impossible to connect {outgoing_value.name} with "
f"{model.layers[-1].name}.") from e
def __call__(self, inputs, sess_options=None, providers=None, provider_options=None):
"""Runs inference on the hybrid model.
Args:
inputs (np.ndarray): input data for the model.
sess_options (onnxruntime.SessionOptions, optional): options for the ORT session.
Defaults to None.
providers (list, optional): list of execution providers for ORT. Defaults to None.
            provider_options (list, optional): list of provider options for ORT.
                Defaults to None.

        Returns:
            list: the model outputs.
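
        A hedged sketch (the input shape and dtype are illustrative assumptions):

        >>> outputs = hybrid(np.zeros((1, 3, 224, 224), dtype=np.float32))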
"""
sess = onnxruntime.InferenceSession(ir.to_proto(self.model).SerializeToString(),
sess_options=sess_options, providers=providers,
provider_options=provider_options)
ort_inputs = {sess.get_inputs()[0].name: inputs}
return sess.run(None, ort_inputs)
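

# A minimal end-to-end sketch (illustrative only: the ONNX file, tensor names, program
# directory and attached device are assumptions; registration goes through the private
# '_add' helper here purely for illustration):
#
#     hybrid = HybridModel(onnx.load("model.onnx"))
#     hybrid._add(ak_model, "conv1_out", "conv5_out")   # replace an ONNX segment by Akida
#     hybrid.map(akida.devices()[0])                    # map on a physical device
#     inference_model = hybrid.generate_inference_model(dirpath="/tmp")
#     hybrid.summary()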