# @file wots_gen_pk.py
# @version 1.1 (2026-01-23T12:43Z)
# @author David Ireland <https://di-mgt.com.au/contact>
# @copyright 2023-26 DI Management Services Pty Ltd
# @license Apache-2.0

"""Compute PK.root: the root node of the top-most subtree of the SLH-DSA hypertree (HT)."""
from slh_adrs import Adrs
from slh_sha256 import PRF, T_len
from slh_util import chain, hash_root

# Global vars
# Test-vector inputs (hex strings) and the expected public-key root.
PKseed = 'FA495FB834DEFEA7CC96A81309479135'
SKseed = 'D5213BA4BB6470F1B9EDA88CBC94E627'
PKroot = 'A67029E90668C5A58B96E60111491F3D'  # Expected result
w = 16          # Winternitz parameter: each chain is walked w - 1 steps below
# Number of WOTS+ chains (called `len` in FIPS 205). Named wots_len so the
# builtin len() is not shadowed.
wots_len = 35
a = 3           # height of the subtree: it has t = 2**a leaves
t = 2 ** a
iscompr = True  # serialize every address in compressed form

# Known-answer values for selected chain public values,
# keyed by (leaf_idx, chainaddr).
KNOWN_CHAIN_PK = {
    (0, 0): "c99a06d927e9b37f48dc68e3a867ea42",
    (0, 34): "e9f5cec1ef3b3597e23cee7e9249145a",
}

# Compute leaves of top-most subtree at layer 21 with tree_addr 0
tree_addr = 0
this_layer = 21
leaves = []
for leaf_idx in range(t):
    print(f"leaf_idx={leaf_idx}")
    adrs = Adrs(Adrs.WOTS_HASH, layer=this_layer)
    adrs.setTreeAddress(tree_addr)
    adrs.setKeyPairAddress(leaf_idx)
    print(adrs.toHex(iscompr))
    skAdrs = adrs.copy()
    skAdrs.setType(Adrs.WOTS_PRF)
    # Re-apply the key pair address; presumably setType() clears it
    # (as FIPS 205 setTypeAndClear does) -- confirm in slh_adrs.
    skAdrs.setKeyPairAddress(adrs.getKeyPairAddress())
    chain_heads = []  # end value of each WOTS+ chain, in chain order
    for chainaddr in range(wots_len):
        # [v3.1] Use WOTS_PRF to create sk, but WOTS_HASH for pk
        skAdrs.setChainAddress(chainaddr)
        sk_adrs_hex = skAdrs.toHex(iscompr)
        print(f"sk_adrs={sk_adrs_hex}")
        # Compute secret value for chain i
        sk = PRF(PKseed, SKseed, sk_adrs_hex)
        print(f"sk[{chainaddr}]={sk}")
        # Compute public value for chain i
        adrs.setChainAddress(chainaddr)
        pk_adrs_hex = adrs.toHex(iscompr)
        print(f"pk_adrs={pk_adrs_hex}")
        # Debug output only for the first two chains and the last one.
        pk = chain(sk, 0, w - 1, PKseed, pk_adrs_hex,
                   showdebug=(chainaddr < 2 or chainaddr == wots_len - 1))
        print(f"pk={pk}")
        expected_pk = KNOWN_CHAIN_PK.get((leaf_idx, chainaddr))
        if expected_pk is not None:
            print(f"OK={expected_pk}")
        chain_heads.append(pk)
    heads = "".join(chain_heads)  # concatenation of heads of WOTS+ chains

    print(f"Input to thash:\n{heads}")
    # Compress the chain heads with T_len under a WOTS_PK address for this
    # leaf. The author's reference compressed value (leaf_idx=0) was
    # 15000000000000000001000000000000000000000000
    wots_pk_adrs = Adrs(Adrs.WOTS_PK, layer=this_layer)
    wots_pk_adrs.setTreeAddress(tree_addr)
    wots_pk_adrs.setKeyPairAddress(leaf_idx)
    # Use the same serialization flag as every other address in this script.
    wots_pk_addr_hex = wots_pk_adrs.toHex(iscompr)
    print(f"wots_pk_addr={wots_pk_addr_hex}")
    leaf = T_len(PKseed, wots_pk_addr_hex, heads)
    print(f"leaf[{leaf_idx}]={leaf}")
    leaves.append(leaf)

print(leaves)
# Known-answer checks on the first and last leaves (t == 8 here).
assert leaves[0] == "2bd8631b653542190503f84e352c7494"
assert leaves[7] == "f9247aa899cf41af17a42668b30c1573"


# Compute the root node of the Merkle tree using H.
# Start from the t leaf values computed above. Pass a copy, so that an
# in-place reduction inside hash_root cannot alter `leaves`.
hashes = list(leaves)
# Top-most tree out of 22, layer=21=0x15
adrs = Adrs(Adrs.TREE, layer=this_layer)
# Same tree address as the leaves (0 for the top-most tree).
adrs.setTreeAddress(tree_addr)

root = hash_root(hashes, adrs, PKseed)
print(f"root={root}")
print(f"OK  ={PKroot}")
# Check the final result, not just the leaves. PKroot is written in upper
# case while computed values are lower-case hex, so compare case-insensitively.
assert root.lower() == PKroot.lower()