#! /usr/bin/python3
# Last edited on 2024-07-27 00:35:00 by stolfi

# Test program for module {trafo}

import trafo
import sys
from math import sqrt, sin, cos, floor, ceil, inf, nan, pi

class MyTrafo(trafo.Trafo):
  """A subclass of {trafo.Trafo} that adds nothing of its own.
  Declaring it checks that {trafo.Trafo} can be subclassed; the class
  is not otherwise used by this script."""
   
sys.stderr.write("--- testing {trafo.identity,trafo.is_identity} ----------------\n")

# The identity trafo is named "None"; a tuple holding only a sign and no
# factors, such as (+1,) or (-1,), is also an identity.
T0 = trafo.identity()
sys.stderr.write("T0 = %s\n" % trafo.get_name(T0))
assert trafo.get_name(T0) == "None"
assert trafo.is_identity(T0)
for sgn in (+1, -1):
  assert trafo.is_identity((sgn,))

# A freshly created atomic trafo keeps its given name and is never the
# identity, whether bare or wrapped in a tuple with either sign.
name1 = "T1"
T1 = trafo.Trafo(name1)
sys.stderr.write("T1 = %s\n" % trafo.get_name(T1))
assert isinstance(T1, trafo.Trafo)
assert trafo.get_name(T1) == name1
assert not trafo.is_identity(T1)
for sgn in (+1, -1):
  assert not trafo.is_identity((sgn, T1))

sys.stderr.write("--- testing {trafo.unpack} ----------------\n")

# {trafo.unpack(T)} yields a sign {kinv} and a sequence of factors.
# The identity has sign +1 and no factors.
kinv0, TL0 = trafo.unpack(T0)
assert (kinv0, len(TL0)) == (+1, 0)

# An atomic trafo has sign +1 and is its own single factor.
kinv1, TL1 = trafo.unpack(T1)
assert (kinv1, len(TL1)) == (+1, 1)
assert TL1[0] == T1

sys.stderr.write("--- testing {trafo.inv} ----------------\n")

# Inverting the identity, in any of its forms, must give {None}.  Test with
# {is}: {None} is a singleton, and {==} could be satisfied by an object
# with a permissive {__eq__}.
T0C = trafo.inv(T0)
sys.stderr.write("T0C = %s\n" % trafo.get_name(T0C))
assert T0C is None
assert trafo.inv((+1,)) is None
assert trafo.inv((-1,)) is None

# The inverse of an atomic trafo {T} is the pair {(-1,T)}, named like {T}
# with a prime appended.  Inverting it again gives back {T} itself, and
# {(+1,T)} inverts the same way as bare {T}.
T1C = trafo.inv(T1)
sys.stderr.write("T1C = %s\n" % trafo.get_name(T1C))
assert trafo.get_name(T1C) == name1 + "'"
assert not trafo.is_identity(T1C)
assert T1C == (-1,T1)
assert trafo.inv((+1,T1)) == T1C
assert trafo.inv((-1,T1)) == T1
assert trafo.inv(T1C) == T1

# Unpacking the inverse exposes the negative sign and the single factor.
kinv1C, TL1C = trafo.unpack(T1C)
assert kinv1C == -1
assert len(TL1C) == 1
assert TL1C[0] == T1
 
sys.stderr.write("--- testing {get_name,set_name} ----------------\n")

# Renaming an atomic trafo is reflected in the name of its inverse,
# which gets the new name plus a prime.
trafo.set_name(T1, "Foo")
sys.stderr.write("trafo.get_name(T1) == %s\n" % trafo.get_name(T1))
for Tx, expected in ((T1, "Foo"), (T1C, "Foo'")):
  assert trafo.get_name(Tx) == expected

# Naming the inverse instead gives the underlying trafo the primed name.
trafo.set_name(T1C, "Bar")
for Tx, expected in ((T1, "Bar'"), (T1C, "Bar")):
  assert trafo.get_name(Tx) == expected

# Put the original name back; later checks expect names like "(T1 T2)".
trafo.set_name(T1, name1)

sys.stderr.write("--- testing {trafo.compose} ----------------\n")

T2 = trafo.Trafo("T2")
assert not trafo.is_identity(T2)

# {trafo.compose} takes its factors as either a tuple or a list.
# With no factors at all the result is the identity.
for empty in ((), []):
  assert trafo.is_identity(trafo.compose(empty))

# With a single factor the result is that factor, unchanged.
for single in ((T1,), [T1]):
  assert trafo.compose(single) == T1
for single in ((trafo.inv(T2),), [trafo.inv(T2)]):
  assert trafo.compose(single) == trafo.inv(T2)

# Two atomic factors give the tuple (+1,T1,T2), named "(T1 T2)".
Tc12 = trafo.compose((T1,T2))
Td12 = trafo.compose([T1,T2])
for label, Tx in (("Tc12", Tc12), ("Td12", Td12)):
  sys.stderr.write("%s = %s\n" % (label, trafo.get_name(Tx)))
  assert trafo.get_name(Tx) == "(T1 T2)"
  assert Tx == (+1,T1,T2)

kinvc12, TLc12 = trafo.unpack(Tc12)
assert (kinvc12, TLc12) == (+1, (T1,T2))

# Inverting a composition only flips its sign: (-1,T1,T2), "(T1 T2)'".
Tc21 = trafo.inv(Tc12)
sys.stderr.write("Tc21 = %s\n" % trafo.get_name(Tc21))
assert trafo.get_name(Tc21) == "(T1 T2)'"
assert Tc21 == (-1,T1,T2)

kinvc21, TLc21 = trafo.unpack(Tc21)
assert (kinvc21, TLc21) == (-1, (T1,T2))

# Adjacent factor/inverse pairs (here T2 T2') cancel out when composing.
Tc1222 = trafo.compose((T1,T2,trafo.inv(T2),T2))
sys.stderr.write("Tc1222 = %s\n" % trafo.get_name(Tc1222))
assert trafo.get_name(Tc1222) == "(T1 T2)"
assert Tc1222 == (+1,T1,T2)

kinvc1222, TLc1222 = trafo.unpack(Tc1222)
assert (kinvc1222, TLc1222) == (+1, (T1,T2))

def check_factor_list(label, TL):
  """Print the factor list {TL} to {stderr} and check its shape.

  The factors are listed one per line under the header "{label} = [" and
  closed by "]".  {TL} must be a {list}, and each element must be either an
  atomic {trafo.Trafo} or a pair {(-1,T)} where {T} is an atomic
  {trafo.Trafo}."""
  assert type(TL) is list
  sys.stderr.write("%s = [\n" % label)
  for Tk in TL:
    sys.stderr.write("  %s\n" % trafo.get_name(Tk))
    assert Tk is not None
    if not isinstance(Tk, trafo.Trafo):
      # Not atomic, so it must be an inverted atomic trafo {(-1,T)}.
      assert type(Tk) is tuple and len(Tk) == 2
      assert Tk[0] == -1
      assert Tk[1] is not None
      assert isinstance(Tk[1], trafo.Trafo)
  sys.stderr.write("]\n")


sys.stderr.write("--- testing {trafo.flatten} ----------------\n")

# {trafo.flatten} turns a trafo into a plain list of atomic factors,
# each either {T} or {(-1,T)}, expanding nested and inverted tuples.
T0Lf = trafo.flatten(T0)
assert T0Lf == []

T1Lf = trafo.flatten(T1)
sys.stderr.write("T1Lf = %s\n" % str(T1Lf))
assert T1Lf == [T1]

T1CLf = trafo.flatten(T1C)
sys.stderr.write("T1CLf = %s\n" % str(T1CLf))
assert T1CLf == [ (-1, T1) ]

# The inverse of a composition reverses the order of its factors.
Tdup = (-1, T1, T2)
TdupLf = trafo.flatten(Tdup)
sys.stderr.write("TdupLf = %s\n" % str(TdupLf))
assert TdupLf == [ (-1, T2), (-1, T1) ]

# An odd number of nested inversions leaves a single inversion...
Tnnn = (-1, (-1, (-1, T1)))
TnnnLf = trafo.flatten(Tnnn)
sys.stderr.write("TnnnLf = %s\n" % str(TnnnLf))
assert TnnnLf == [ (-1, T1) ]

# ... and an even number cancels out completely.
Tnnd = (-1, (-1, (+1, T1, T2)))
TnndLf = trafo.flatten(Tnnd)
sys.stderr.write("TnndLf = %s\n" % str(TnndLf))
assert TnndLf == [ T1, T2 ]

T3 = trafo.Trafo("T3")

Tbig = (-1,(+1,T1,T3),(-1,(-1,T2,T3),(-1,trafo.inv(T3),trafo.inv(T2),T3,trafo.inv(T3),T2)))
sys.stderr.write("Tbig = %s\n" % trafo.get_name(Tbig))

TbigLf = trafo.flatten(Tbig)
# Hand expansion of {Tbig}, one step per line, where X' denotes the
# inverse of X.  The last line is the expected flattened list.
# Tbig = "((T1 T3) (T3' T2' (T3' T2' T3 T3' T2)' )' )'"
# Tbig = "(T1 T3   (T3' T2' (T2' T3 T3' T2 T3)   )' )'"
# Tbig = "(T1 T3 (T3' T2' T2' T3 T3' T2 T3)' )'"
# Tbig = "(T1 T3 T3' T2' T3 T3' T2 T2 T3)'"
# Tbig = "(T3' T2' T2' T3 T3' T2 T3 T3' T1')"

check_factor_list("TbigLf", TbigLf)
assert len(TbigLf) == 9
assert TbigLf == [(-1, T3), (-1, T2), (-1, T2), T3, (-1, T3), T2, T3, (-1, T3), (-1, T1)]

sys.stderr.write("--- testing {trafo.reduce} ----------------\n")

TbigLr = trafo.reduce(TbigLf)
# Expected result once adjacent factor/inverse pairs cancel:
# TbigLr = "(T3' T2' T1')"

check_factor_list("TbigLr", TbigLr)
assert len(TbigLr) == 3
assert TbigLr == [(-1, T3), (-1, T2), (-1, T1)]

sys.stderr.write("--- testing {trafo.simplify} ----------------\n")

# Simplifying {Tbig} should give a composition with sign +1 whose three
# factors are the inverses of T3, T2 and T1, in that order.
Tbigs = trafo.simplify(Tbig)
sys.stderr.write("Tbigs = %s\n" % trafo.get_name(Tbigs))

kinvbigs, TLbigs = trafo.unpack(Tbigs)
assert kinvbigs == +1
assert len(TLbigs) == 3
for Tfac, Tatom in zip(TLbigs, (T3, T2, T1)):
  assert Tfac == trafo.inv(Tatom)
