# @file wots_gen_pk.py
# @version 1.1 (2026-01-23T12:43Z)
# @author David Ireland <https://di-mgt.com.au/contact>
# @copyright 2023-26 DI Management Services Pty Ltd
# @license Apache-2.0
"""Compute root node of the top-most HT subtree = PK.root."""
from slh_adrs import Adrs
from slh_sha256 import PRF, T_len
from slh_util import chain, hash_root
# Global vars (test inputs and parameters for this worked example)
PKseed = 'FA495FB834DEFEA7CC96A81309479135'
SKseed = 'D5213BA4BB6470F1B9EDA88CBC94E627'
PKroot = 'A67029E90668C5A58B96E60111491F3D'  # Expected result
w = 16  # each chain below is evaluated as chain(sk, 0, w - 1, ...)
# Number of WOTS+ chains per leaf. Named wots_len (not `len`) so the
# builtin len() is not shadowed.
wots_len = 35
a = 3
t = 2 ** a  # number of leaves computed below
iscompr = True  # flag passed to every Adrs.toHex() call

# Compute leaves of top-most subtree at layer 21 with tree_addr 0
tree_addr = 0
this_layer = 21

# Reference chain heads for leaf 0, keyed by chain index. They are only
# printed next to the computed value for visual comparison, not asserted.
known_heads = {
    0: "c99a06d927e9b37f48dc68e3a867ea42",
    34: "e9f5cec1ef3b3597e23cee7e9249145a",
}
last_chain = wots_len - 1

leaves = []
for leaf_idx in range(t):
    print(f"leaf_idx={leaf_idx}")
    adrs = Adrs(Adrs.WOTS_HASH, layer=this_layer)
    adrs.setTreeAddress(tree_addr)
    adrs.setKeyPairAddress(leaf_idx)
    print(adrs.toHex(iscompr))
    skAdrs = adrs.copy()
    skAdrs.setType(Adrs.WOTS_PRF)
    # NOTE(review): the key pair address is re-applied after setType();
    # presumably setType() resets it -- confirm against slh_adrs.
    skAdrs.setKeyPairAddress(adrs.getKeyPairAddress())
    chain_heads = []  # heads of the WOTS+ chains, in chain order
    for chainaddr in range(wots_len):
        # [v3.1] Use WOTS_PRF to create sk, but WOTS_HASH for pk
        skAdrs.setChainAddress(chainaddr)
        sk_adrs_hex = skAdrs.toHex(iscompr)
        print(f"sk_adrs={sk_adrs_hex}")
        # Compute secret value for chain i
        sk = PRF(PKseed, SKseed, sk_adrs_hex)
        print(f"sk[{chainaddr}]={sk}")
        # Compute public value for chain i
        adrs.setChainAddress(chainaddr)
        pk_adrs_hex = adrs.toHex(iscompr)
        print(f"pk_adrs={pk_adrs_hex}")
        # Only show debug output for the first two chains and the last one
        pk = chain(sk, 0, w - 1, PKseed, pk_adrs_hex,
                   showdebug=(chainaddr < 2 or chainaddr == last_chain))
        print(f"pk={pk}")
        if leaf_idx == 0 and chainaddr in known_heads:
            print(f"OK={known_heads[chainaddr]}")
        chain_heads.append(pk)
    # Concatenation of heads of WOTS+ chains
    heads = "".join(chain_heads)
    print(f"Input to thash:\n{heads}")
    # Address used when compressing the chain heads into the leaf value
    wots_pk_adrs = Adrs(Adrs.WOTS_PK, layer=this_layer)
    wots_pk_adrs.setTreeAddress(tree_addr)
    wots_pk_adrs.setKeyPairAddress(leaf_idx)
    # Use the same compression flag as every other address in this script
    wots_pk_addr_hex = wots_pk_adrs.toHex(iscompr)
    print(f"wots_pk_addr={wots_pk_addr_hex}")
    leaf = T_len(PKseed, wots_pk_addr_hex, heads)
    print(f"leaf[{leaf_idx}]={leaf}")
    leaves.append(leaf)
print(leaves)
# Known-answer check: compare selected leaves against their expected values
for known_idx, known_leaf in (
    (0, "2bd8631b653542190503f84e352c7494"),
    (7, "f9247aa899cf41af17a42668b30c1573"),
):
    assert leaves[known_idx] == known_leaf
# Compute the root node of the Merkle tree using H,
# starting from the 8 leaf values computed above.
# Top-most tree out of 22, layer=21=0x15
tree_adrs = Adrs(Adrs.TREE, layer=this_layer)
root = hash_root(leaves, tree_adrs, PKseed)
# Print the computed root and the expected value on adjacent lines
# so they can be compared by eye.
print(f"root={root}")
print(f"OK ={PKroot}")