#! /usr/bin/python3

import os
import sys
import subprocess
import multiprocessing

def thread(task):
  """Run one ntruprime-test invocation and classify its outcome.

  task is (valgrind, o, p, n, offset). The binary is invoked as
  `ntruprime-test o p . n offset`; when valgrind is true that command is
  prefixed with `env valgrind_multiplier=1 valgrind -q --error-exitcode=99`.

  Returns (task, (failure, stdout, stderr)), where failure is None when the
  run is accepted and otherwise a short description of the first problem.
  """
  valgrind,o,p,n,offset = task

  argv = ['ntruprime-test',o,p,'.',str(n),str(offset)]
  if valgrind:
    prefix = ['env','valgrind_multiplier=1','valgrind','-q','--error-exitcode=99']
    argv = prefix+argv

  proc = subprocess.run(argv,stdout=subprocess.PIPE,stderr=subprocess.PIPE,universal_newlines=True)
  out,err = proc.stdout,proc.stderr

  def verdict(failure):
    # package the outcome together with the task it belongs to
    return task,(failure,out,err)

  if proc.returncode:
    return verdict(f'nonzero return code {proc.returncode}')
  if err: # any stderr output is treated as a bug in the test program
    return verdict('nonempty errors')

  words = [line.split() for line in out.splitlines()]
  saidsuccess = ['all','tests','succeeded'] in words
  saidimplementation = any(len(x) >= 3 and x[2] == 'implementation' for x in words)
  saiddeclassify = ['valgrind','1','declassify','1'] in words

  if not saidsuccess: # bug in test program
    return verdict('did not say all tests succeeded')
  if not saidimplementation:
    return verdict('CPU does not support implementation')
  if valgrind and not saiddeclassify:
    return verdict('test does not support declassify')
  return verdict(None)

def checkresult(task,result):
  """Print the captured output and verdict of one finished task.

  Each stdout line is echoed as '<label> output: ...' and each stderr line
  as '<label> error: ...'. A final '<label> result: ...' line follows, and
  stdout is flushed. <label> names the task and whether it ran under
  valgrind ('dataflow') or natively ('conventional').

  Returns True if the task succeeded (failure is None), else False.
  """
  valgrind,o,p,n,offset = task
  failure,out,err = result

  label = f'{o}/{p} impl {n} offset {offset}' + (' dataflow' if valgrind else ' conventional')

  report = [f'{label} output: {line}' for line in out.splitlines()]
  report += [f'{label} error: {line}' for line in err.splitlines()]
  if failure is None:
    report.append(f'{label} result: success')
  else:
    report.append(f'{label} result: {failure}')

  print('\n'.join(report))
  sys.stdout.flush()
  return failure is None

def doit(todo):
  """Run every task in todo on a process pool and report results in order.

  Tasks run concurrently via `thread`, but their reports are printed by
  `checkresult` strictly in todo order. A finished result is buffered until
  every earlier task has been printed.

  The pool size defaults to the number of CPUs this process may run on
  (falling back to multiprocessing.cpu_count()). The THREADS environment
  variable overrides it, and the result is clamped to [1, len(todo)].

  Returns True if every task succeeded (an empty todo trivially does),
  False otherwise.
  """
  todo = list(todo)
  print(f'tests to run: {len(todo)}')

  try:
    threads = len(os.sched_getaffinity(0))
  except (AttributeError,OSError): # sched_getaffinity is not available on every platform
    threads = multiprocessing.cpu_count()
  threads = int(os.getenv('THREADS',default=threads))
  # clamp to len(todo) first and to at least 1 last: Pool(0) raises ValueError
  threads = max(min(threads,len(todo)),1)
  print(f'maximum threads allowed: {threads}')

  if not todo: # nothing to run; don't start an idle pool
    return True

  results = {}
  printpos = 0
  ok = True

  with multiprocessing.Pool(threads) as p:
    for task,result in p.imap_unordered(thread,todo,chunksize=1):
      results[task] = result
      # print every consecutive already-finished task starting at printpos
      while printpos < len(todo):
        nexttask = todo[printpos]
        if nexttask not in results: break
        if not checkresult(nexttask,results[nexttask]):
          ok = False
        printpos += 1

  assert printpos == len(todo)
  return ok

# Base test matrix. Each entry is (o, p, n), where o and p are the first
# two arguments to ntruprime-test (thread() runs `ntruprime-test o p . n offset`)
# and checkresult() labels the run as "{o}/{p} impl {n}".
# NOTE(review): the meaning of n (-1, 0, 1) is defined by ntruprime-test,
# not visible here; presumably an implementation selector. Confirm there.
# NOTE(review): ('encode','int16'), ('hash','sha512') and ('kem','sntrup953')
# have no n=1 entry, unlike their neighbours; confirm this is intentional.
# Order matters: doit() prints results in the order of the final task list.
# Offsets and valgrind variants are added to each entry after this list.
todo = [
  ('verify','897',-1),
 ('verify','897',0),
 ('verify','897',1),
  ('verify','1039',-1),
 ('verify','1039',0),
 ('verify','1039',1),
  ('verify','1184',-1),
 ('verify','1184',0),
 ('verify','1184',1),
  ('verify','1349',-1),
 ('verify','1349',0),
 ('verify','1349',1),
  ('verify','1455',-1),
 ('verify','1455',0),
 ('verify','1455',1),
  ('verify','1847',-1),
 ('verify','1847',0),
 ('verify','1847',1),
  ('decode','653x3',-1),
 ('decode','653x3',0),
 ('decode','653x3',1),
  ('decode','653x1541',-1),
 ('decode','653x1541',0),
 ('decode','653x1541',1),
  ('decode','653x4621',-1),
 ('decode','653x4621',0),
 ('decode','653x4621',1),
  ('decode','653xint16',-1),
 ('decode','653xint16',0),
 ('decode','653xint16',1),
  ('decode','653xint32',-1),
 ('decode','653xint32',0),
 ('decode','653xint32',1),
  ('decode','761x3',-1),
 ('decode','761x3',0),
 ('decode','761x3',1),
  ('decode','761x1531',-1),
 ('decode','761x1531',0),
 ('decode','761x1531',1),
  ('decode','761x4591',-1),
 ('decode','761x4591',0),
 ('decode','761x4591',1),
  ('decode','761xint16',-1),
 ('decode','761xint16',0),
 ('decode','761xint16',1),
  ('decode','761xint32',-1),
 ('decode','761xint32',0),
 ('decode','761xint32',1),
  ('decode','857x3',-1),
 ('decode','857x3',0),
 ('decode','857x3',1),
  ('decode','857x1723',-1),
 ('decode','857x1723',0),
 ('decode','857x1723',1),
  ('decode','857x5167',-1),
 ('decode','857x5167',0),
 ('decode','857x5167',1),
  ('decode','857xint16',-1),
 ('decode','857xint16',0),
 ('decode','857xint16',1),
  ('decode','857xint32',-1),
 ('decode','857xint32',0),
 ('decode','857xint32',1),
  ('decode','953x3',-1),
 ('decode','953x3',0),
 ('decode','953x3',1),
  ('decode','953x2115',-1),
 ('decode','953x2115',0),
 ('decode','953x2115',1),
  ('decode','953x6343',-1),
 ('decode','953x6343',0),
 ('decode','953x6343',1),
  ('decode','953xint16',-1),
 ('decode','953xint16',0),
 ('decode','953xint16',1),
  ('decode','953xint32',-1),
 ('decode','953xint32',0),
 ('decode','953xint32',1),
  ('decode','1013x3',-1),
 ('decode','1013x3',0),
 ('decode','1013x3',1),
  ('decode','1013x2393',-1),
 ('decode','1013x2393',0),
 ('decode','1013x2393',1),
  ('decode','1013x7177',-1),
 ('decode','1013x7177',0),
 ('decode','1013x7177',1),
  ('decode','1013xint16',-1),
 ('decode','1013xint16',0),
 ('decode','1013xint16',1),
  ('decode','1013xint32',-1),
 ('decode','1013xint32',0),
 ('decode','1013xint32',1),
  ('decode','1277x3',-1),
 ('decode','1277x3',0),
 ('decode','1277x3',1),
  ('decode','1277x2627',-1),
 ('decode','1277x2627',0),
 ('decode','1277x2627',1),
  ('decode','1277x7879',-1),
 ('decode','1277x7879',0),
 ('decode','1277x7879',1),
  ('decode','1277xint16',-1),
 ('decode','1277xint16',0),
 ('decode','1277xint16',1),
  ('decode','1277xint32',-1),
 ('decode','1277xint32',0),
 ('decode','1277xint32',1),
  ('decode','int16',-1),
 ('decode','int16',0),
 ('decode','int16',1),
  ('encode','653x3',-1),
 ('encode','653x3',0),
 ('encode','653x3',1),
  ('encode','653x1541',-1),
 ('encode','653x1541',0),
 ('encode','653x1541',1),
  ('encode','653x1541round',-1),
 ('encode','653x1541round',0),
 ('encode','653x1541round',1),
  ('encode','653x4621',-1),
 ('encode','653x4621',0),
 ('encode','653x4621',1),
  ('encode','653xfreeze3',-1),
 ('encode','653xfreeze3',0),
 ('encode','653xfreeze3',1),
  ('encode','653xint16',-1),
 ('encode','653xint16',0),
 ('encode','653xint16',1),
  ('encode','761x3',-1),
 ('encode','761x3',0),
 ('encode','761x3',1),
  ('encode','761x1531',-1),
 ('encode','761x1531',0),
 ('encode','761x1531',1),
  ('encode','761x1531round',-1),
 ('encode','761x1531round',0),
 ('encode','761x1531round',1),
  ('encode','761x4591',-1),
 ('encode','761x4591',0),
 ('encode','761x4591',1),
  ('encode','761xfreeze3',-1),
 ('encode','761xfreeze3',0),
 ('encode','761xfreeze3',1),
  ('encode','761xint16',-1),
 ('encode','761xint16',0),
 ('encode','761xint16',1),
  ('encode','857x3',-1),
 ('encode','857x3',0),
 ('encode','857x3',1),
  ('encode','857x1723',-1),
 ('encode','857x1723',0),
 ('encode','857x1723',1),
  ('encode','857x1723round',-1),
 ('encode','857x1723round',0),
 ('encode','857x1723round',1),
  ('encode','857x5167',-1),
 ('encode','857x5167',0),
 ('encode','857x5167',1),
  ('encode','857xfreeze3',-1),
 ('encode','857xfreeze3',0),
 ('encode','857xfreeze3',1),
  ('encode','857xint16',-1),
 ('encode','857xint16',0),
 ('encode','857xint16',1),
  ('encode','953x3',-1),
 ('encode','953x3',0),
 ('encode','953x3',1),
  ('encode','953x2115',-1),
 ('encode','953x2115',0),
 ('encode','953x2115',1),
  ('encode','953x2115round',-1),
 ('encode','953x2115round',0),
 ('encode','953x2115round',1),
  ('encode','953x6343',-1),
 ('encode','953x6343',0),
 ('encode','953x6343',1),
  ('encode','953xfreeze3',-1),
 ('encode','953xfreeze3',0),
 ('encode','953xfreeze3',1),
  ('encode','953xint16',-1),
 ('encode','953xint16',0),
 ('encode','953xint16',1),
  ('encode','1013x3',-1),
 ('encode','1013x3',0),
 ('encode','1013x3',1),
  ('encode','1013x2393',-1),
 ('encode','1013x2393',0),
 ('encode','1013x2393',1),
  ('encode','1013x2393round',-1),
 ('encode','1013x2393round',0),
 ('encode','1013x2393round',1),
  ('encode','1013x7177',-1),
 ('encode','1013x7177',0),
 ('encode','1013x7177',1),
  ('encode','1013xfreeze3',-1),
 ('encode','1013xfreeze3',0),
 ('encode','1013xfreeze3',1),
  ('encode','1013xint16',-1),
 ('encode','1013xint16',0),
 ('encode','1013xint16',1),
  ('encode','1277x3',-1),
 ('encode','1277x3',0),
 ('encode','1277x3',1),
  ('encode','1277x2627',-1),
 ('encode','1277x2627',0),
 ('encode','1277x2627',1),
  ('encode','1277x2627round',-1),
 ('encode','1277x2627round',0),
 ('encode','1277x2627round',1),
  ('encode','1277x7879',-1),
 ('encode','1277x7879',0),
 ('encode','1277x7879',1),
  ('encode','1277xfreeze3',-1),
 ('encode','1277xfreeze3',0),
 ('encode','1277xfreeze3',1),
  ('encode','1277xint16',-1),
 ('encode','1277xint16',0),
 ('encode','1277xint16',1),
  ('encode','int16',-1),
 ('encode','int16',0),
  ('sort','int32',-1),
 ('sort','int32',0),
 ('sort','int32',1),
  ('sort','uint32',-1),
 ('sort','uint32',0),
 ('sort','uint32',1),
  ('core','inv3sntrup653',-1),
 ('core','inv3sntrup653',0),
 ('core','inv3sntrup653',1),
  ('core','inv3sntrup761',-1),
 ('core','inv3sntrup761',0),
 ('core','inv3sntrup761',1),
  ('core','inv3sntrup857',-1),
 ('core','inv3sntrup857',0),
 ('core','inv3sntrup857',1),
  ('core','inv3sntrup953',-1),
 ('core','inv3sntrup953',0),
 ('core','inv3sntrup953',1),
  ('core','inv3sntrup1013',-1),
 ('core','inv3sntrup1013',0),
 ('core','inv3sntrup1013',1),
  ('core','inv3sntrup1277',-1),
 ('core','inv3sntrup1277',0),
 ('core','inv3sntrup1277',1),
  ('core','invsntrup653',-1),
 ('core','invsntrup653',0),
 ('core','invsntrup653',1),
  ('core','invsntrup761',-1),
 ('core','invsntrup761',0),
 ('core','invsntrup761',1),
  ('core','invsntrup857',-1),
 ('core','invsntrup857',0),
 ('core','invsntrup857',1),
  ('core','invsntrup953',-1),
 ('core','invsntrup953',0),
 ('core','invsntrup953',1),
  ('core','invsntrup1013',-1),
 ('core','invsntrup1013',0),
 ('core','invsntrup1013',1),
  ('core','invsntrup1277',-1),
 ('core','invsntrup1277',0),
 ('core','invsntrup1277',1),
  ('core','mult3sntrup653',-1),
 ('core','mult3sntrup653',0),
 ('core','mult3sntrup653',1),
  ('core','mult3sntrup761',-1),
 ('core','mult3sntrup761',0),
 ('core','mult3sntrup761',1),
  ('core','mult3sntrup857',-1),
 ('core','mult3sntrup857',0),
 ('core','mult3sntrup857',1),
  ('core','mult3sntrup953',-1),
 ('core','mult3sntrup953',0),
 ('core','mult3sntrup953',1),
  ('core','mult3sntrup1013',-1),
 ('core','mult3sntrup1013',0),
 ('core','mult3sntrup1013',1),
  ('core','mult3sntrup1277',-1),
 ('core','mult3sntrup1277',0),
 ('core','mult3sntrup1277',1),
  ('core','multsntrup653',-1),
 ('core','multsntrup653',0),
 ('core','multsntrup653',1),
  ('core','multsntrup761',-1),
 ('core','multsntrup761',0),
 ('core','multsntrup761',1),
  ('core','multsntrup857',-1),
 ('core','multsntrup857',0),
 ('core','multsntrup857',1),
  ('core','multsntrup953',-1),
 ('core','multsntrup953',0),
 ('core','multsntrup953',1),
  ('core','multsntrup1013',-1),
 ('core','multsntrup1013',0),
 ('core','multsntrup1013',1),
  ('core','multsntrup1277',-1),
 ('core','multsntrup1277',0),
 ('core','multsntrup1277',1),
  ('core','scale3sntrup653',-1),
 ('core','scale3sntrup653',0),
 ('core','scale3sntrup653',1),
  ('core','scale3sntrup761',-1),
 ('core','scale3sntrup761',0),
 ('core','scale3sntrup761',1),
  ('core','scale3sntrup857',-1),
 ('core','scale3sntrup857',0),
 ('core','scale3sntrup857',1),
  ('core','scale3sntrup953',-1),
 ('core','scale3sntrup953',0),
 ('core','scale3sntrup953',1),
  ('core','scale3sntrup1013',-1),
 ('core','scale3sntrup1013',0),
 ('core','scale3sntrup1013',1),
  ('core','scale3sntrup1277',-1),
 ('core','scale3sntrup1277',0),
 ('core','scale3sntrup1277',1),
  ('core','weightsntrup653',-1),
 ('core','weightsntrup653',0),
 ('core','weightsntrup653',1),
  ('core','weightsntrup761',-1),
 ('core','weightsntrup761',0),
 ('core','weightsntrup761',1),
  ('core','weightsntrup857',-1),
 ('core','weightsntrup857',0),
 ('core','weightsntrup857',1),
  ('core','weightsntrup953',-1),
 ('core','weightsntrup953',0),
 ('core','weightsntrup953',1),
  ('core','weightsntrup1013',-1),
 ('core','weightsntrup1013',0),
 ('core','weightsntrup1013',1),
  ('core','weightsntrup1277',-1),
 ('core','weightsntrup1277',0),
 ('core','weightsntrup1277',1),
  ('core','wforcesntrup653',-1),
 ('core','wforcesntrup653',0),
 ('core','wforcesntrup653',1),
  ('core','wforcesntrup761',-1),
 ('core','wforcesntrup761',0),
 ('core','wforcesntrup761',1),
  ('core','wforcesntrup857',-1),
 ('core','wforcesntrup857',0),
 ('core','wforcesntrup857',1),
  ('core','wforcesntrup953',-1),
 ('core','wforcesntrup953',0),
 ('core','wforcesntrup953',1),
  ('core','wforcesntrup1013',-1),
 ('core','wforcesntrup1013',0),
 ('core','wforcesntrup1013',1),
  ('core','wforcesntrup1277',-1),
 ('core','wforcesntrup1277',0),
 ('core','wforcesntrup1277',1),
  ('hashblocks','sha512',-1),
 ('hashblocks','sha512',0),
 ('hashblocks','sha512',1),
  ('hash','sha512',-1),
 ('hash','sha512',0),
  ('kem','sntrup653',-1),
 ('kem','sntrup653',0),
 ('kem','sntrup653',1),
  ('kem','sntrup761',-1),
 ('kem','sntrup761',0),
 ('kem','sntrup761',1),
  ('kem','sntrup857',-1),
 ('kem','sntrup857',0),
 ('kem','sntrup857',1),
  ('kem','sntrup953',-1),
 ('kem','sntrup953',0),
  ('kem','sntrup1013',-1),
 ('kem','sntrup1013',0),
 ('kem','sntrup1013',1),
  ('kem','sntrup1277',-1),
 ('kem','sntrup1277',0),
 ('kem','sntrup1277',1),
]

# Append an offset to each (o, p, n) entry: encode/decode/sort entries get
# offset 0 only, every other entry gets offsets 0 and 1.
todo = [task+(offset,) for task in todo for offset in range(1 if task[0] in ("encode","decode","sort") else 2)]
# Prepend the valgrind flag. Every task runs natively (False, reported as
# "conventional"); tasks with offset 0 also run under valgrind (True,
# reported as "dataflow"). Native runs come first in the final order.
todo = [(False,)+task for task in todo]+[(True,)+task for task in todo if task[-1] == 0]

# Guard the entry point. Under the "spawn" and "forkserver" multiprocessing
# start methods, worker processes re-import this script. Without the guard,
# each worker would call doit() again and try to start its own pool.
if __name__ == '__main__':
  if not doit(todo):
    print('some tests failed')
    sys.exit(111)
  print('full tests succeeded')
