#!/bin/bash -e

#
# Read the settings from config.in (looked up like any sourced file; in
# practice the script is started from the directory that holds it).
# NOTE(review): run_cmd and the other parameters used below are presumably
# defined there - confirm against config.in.
#
source config.in

#
# Batch mode.
# Called without arguments while run_cmd is "batch_cmd": hand this very
# script to the runner, passing one dummy argument so that the re-invoked
# copy skips this branch, and leave with the runner's exit status.
# test_name is assigned just before the runner is invoked.
#
if (( $# == 0 )) && [[ "$run_cmd" == "batch_cmd" ]]; then
    test_name="${0##*/}"
    ${run_cmd} ./$0 1
    exit
fi

# Put the whole size list on one line: word-split it, print one entry per
# line, then turn every newline (the last one included) into a blank.
dims_small="$(printf '%s\n' ${dims_small} | tr '\n' ' ')"

#
# Build lib_gen.x, the generator of the small multiplication kernels:
# two modules compiled to objects first, then linked with the main program.
#
for gen_unit in mults multrec_gen; do
    ${host_compile} -c ${gen_unit}.f90
done
${host_compile} mults.o multrec_gen.o lib_gen.f90 -o lib_gen.x

#
# lib/ is where the finished archive ends up
#
mkdir -p lib

#
# run_lib<type_label>/ collects every generated source and build product
#
lib_dir="run_lib${type_label}"
mkdir -p ${lib_dir}

#
# generate the generic caller
#

# One pass over the size list:
#   numsize - number of entries in dims_small
#   maxsize - largest entry (stays at -1 for an empty list)
numsize=0
maxsize=-1

for myn in ${dims_small}; do
  numsize=$((numsize+1))
  if (( myn > maxsize )); then
    maxsize=$myn
  fi
done

#
# generate a translation array for the jump table
#
#######################################
# Print the Fortran array constructor used as indx(0:MAXSIZE) by the
# generated dispatcher: "(/0,e1,e2,...,eMAXSIZE/)".
#
# e_i is the 1-based position of size i in the DIM list, or 0 when i is not
# in the list. The SELECT CASE table written further down numbers its cases
# by list position ("for ... in 0 ${dims_small}"), so the translation array
# has to use list positions as well. Ranking the sizes in ascending order
# instead only gives the same numbers when the list is sorted and free of
# duplicates; for any other list the dispatcher would call the wrong kernel.
#
# Arguments: $1   - MAXSIZE, largest size in the list (-1 for an empty list)
#            $2.. - the sizes, in the order of dims_small (positive integers)
# Outputs:   the constructor on stdout, without a trailing newline
#######################################
smm_index_table() {
  local max=${1:-0}
  if (( $# > 0 )); then
    shift
  fi

  local -a slot=()
  local pos=0 dim i table="(/0"

  for dim in "$@"; do
    pos=$((pos+1))
    # A repeated size keeps its first position; every position of a
    # repeated size leads to the same kernel call in the CASE table.
    if [[ -z "${slot[$dim]:-}" ]]; then
      slot[$dim]=$pos
    fi
  done

  for (( i = 1; i <= max; i++ )); do
    table="${table},${slot[$i]:-0}"
  done

  printf '%s/)' "$table"
}

eles=$(smm_index_table "$maxsize" ${dims_small})

# data_type selects the element type of the library:
#   filetype - one-letter prefix used in file and routine names
#   gemm     - BLAS routine used as reference and fallback
#   strdat   - Fortran declaration of one matrix element
# Any other value leaves the three variables untouched.
case "$data_type" in
    1)  # double precision real
        filetype="d"
        gemm="DGEMM"
        strdat="REAL(KIND=KIND(0.0D0))"
        ;;
    2)  # single precision real
        filetype="s"
        gemm="SGEMM"
        strdat="REAL(KIND=KIND(0.0))"
        ;;
    3)  # double precision complex
        filetype="z"
        gemm="ZGEMM"
        strdat="COMPLEX(KIND=KIND(0.0D0))"
        ;;
    4)  # single precision complex
        filetype="c"
        gemm="CGEMM"
        strdat="COMPLEX(KIND=KIND(0.0))"
        ;;
esac



# transpose_flavor selects which of A and B are passed transposed:
#   filetrans - two-letter tag used in file and routine names
#   ta, tb    - TRANSA / TRANSB arguments handed to the BLAS gemm
#   decl      - Fortran shapes of A and B as stored
#   lds       - statements setting the leading dimensions LDA and LDB
# Any other value leaves the five variables untouched.
case "$transpose_flavor" in
    1)  # C += A   * B
        filetrans="nn"
        ta="N" ; tb="N"
        decl="A(M,K), B(K,N)"
        lds="LDA=M ; LDB=K"
        ;;
    2)  # C += A^T * B
        filetrans="tn"
        ta="T" ; tb="N"
        decl="A(K,M), B(K,N)"
        lds="LDA=K ; LDB=K"
        ;;
    3)  # C += A   * B^T
        filetrans="nt"
        ta="N" ; tb="T"
        decl="A(M,K), B(N,K)"
        lds="LDA=M ; LDB=N"
        ;;
    4)  # C += A^T * B^T
        filetrans="tt"
        ta="T" ; tb="T"
        decl="A(K,M), B(N,K)"
        lds="LDA=K ; LDB=N"
        ;;
esac

#
# Full tag = type letter + transpose letters (for instance d + nn);
# the dispatcher lives in smm_<tag>.f90 inside the run directory.
#
filetrans="${filetype}${filetrans}"
file="smm_${filetrans}.f90"

cd ${lib_dir}

# Drop any earlier copy; everything below is appended to a fresh file.
rm -f ${file}

{
  #
  # Declarations, the translation array and the computation of itot.
  # im/in/ik are the slots of M/N/K (0 = size not in the list) and itot
  # combines them as a three-digit number in base (numsize+1).
  #
  printf "SUBROUTINE smm_${filetrans}(M,N,K,A,B,C)\n"
  printf " INTEGER :: M,N,K,LDA,LDB\n ${strdat} :: C(M,N), ${decl}\n"
  printf " INTEGER, PARAMETER :: indx(0:$maxsize)=&\n $eles\n"
  printf " INTEGER :: im,in,ik,itot\n"
  printf " $strdat, PARAMETER :: one=1\n"
  printf " ${lds}\n"
  printf " IF (M<=$maxsize) THEN\n   im=indx(M)\n ELSE\n   im=0\n ENDIF\n"
  printf " IF (N<=$maxsize) THEN\n   in=indx(N)\n ELSE\n   in=0\n ENDIF\n"
  printf " IF (K<=$maxsize) THEN\n   ik=indx(K)\n ELSE\n   ik=0\n ENDIF\n"
  printf " itot=(ik*($numsize+1)+in)*($numsize+1)+im\n"

  #
  # The jump table. Cases are numbered in loop order (k outermost, m
  # innermost), the leading 0 of each loop standing for "size not in the
  # list". A zero product means at least one such slot: jump to the BLAS
  # fallback at label 999.
  #
  printf " SELECT CASE(itot)\n"
  count=0
  for myk in 0 ${dims_small}; do
    for myn in 0 ${dims_small}; do
      for mym in 0 ${dims_small}; do
        printf " CASE($count)\n "
        prod=$((myk*myn*mym))
        if (( prod == 0 )); then
          printf '   GOTO 999\n'
        else
          printf "   CALL smm_${filetrans}_${mym}_${myn}_${myk}(A,B,C)\n"
        fi
        count=$((count+1))
      done
    done
  done
  printf " END SELECT\n"
  printf " RETURN\n"

  #
  # Fallback for sizes without a specialised kernel: plain BLAS gemm
  #
  printf "999 CONTINUE \n CALL ${gemm}('%s','%s',M,N,K,one,A,LDA,B,LDB,one,C,M)\n" $ta $tb
  printf "END SUBROUTINE smm_${filetrans}\n\n"
} >> ${file}

cd ..


#
# Names derived from type_label and, when the variable mic is set (even to
# an empty value), from mic as well:
#   make_file      Makefile.lib<type_label>       | Makefile_<mic>.lib<type_label>
#   library        smm<type_label>                | smm<type_label>_<mic>
#   in_tiny_file   tiny_gen_optimal<type_label>   | tiny_gen_optimal_<mic><type_label>
#   in_small_file  small_gen_optimal<type_label>  | small_gen_optimal_<mic><type_label>
#
mic_suffix="${mic+_${mic}}"

make_file="Makefile${mic_suffix}.lib${type_label}"
library="smm${type_label}${mic_suffix}"
archive="../lib/lib${library}.a"   # relative to the run directory

#
# Files with the tuned parameters for tiny and small kernels (the Makefile
# rule appends ".out" to these names)
#
in_tiny_file="tiny_gen_optimal${mic_suffix}${type_label}"
in_small_file="small_gen_optimal${mic_suffix}${type_label}"

#
# The build goes through a Makefile for easy parallelism; remove a stale
# one before it is regenerated
#
rm -f ${make_file}

(
#
# Everything printed inside this subshell becomes the Makefile: the closing
# parenthesis redirects the collected output into ${make_file}.
#
# Escaping conventions used below:
#   \$   keeps a literal $ so that make (not this script) expands it
#   %%   is a literal % for printf (make pattern character)
#   ${...} without backslash is expanded now, by this script
#
# Original design note: a two stage approach, first compile in parallel,
# once done, execute in parallel
#

#
# DIMS_SMALL is the size list; DIMS_INDICES is every m_n_k combination of
# it (m varies slowest, k fastest), built with nested make foreach calls
#
printf "DIMS_SMALL    = ${dims_small}\n"
printf "DIMS_INDICES = \$(foreach m,\$(DIMS_SMALL),\$(foreach n,\$(DIMS_SMALL),\$(foreach k,\$(DIMS_SMALL),\$m_\$n_\$k)))"
printf "\n\n"

#
# INDICES is the sub-range SI..EI of DIMS_INDICES; the defaults written
# here (1 and the number of words) select the full list
#
printf "SI = 1\n"
printf "EI = \$(words \$(DIMS_INDICES))\n"
printf "INDICES = \$(wordlist \$(SI),\$(EI),\$(DIMS_INDICES))\n\n"

#
# output directory for the object files
#
printf "OUTDIR=${out_dir}\n\n"

#
# object file of the dispatcher: ${file} with its extension replaced by .o
#
printf "DRIVER=${file%.*}.o \n\n"

#
# kernel sources, one smm_<tag>_<m>_<n>_<k>.f90 per entry of INDICES
#
printf "SRCFILES=\$(patsubst %%,smm_${filetrans}_%%.f90,\$(INDICES)) \n"

#
# matching object files, placed in OUTDIR
#
printf "OBJFILES=\$(patsubst %%,\$(OUTDIR)/smm_${filetrans}_%%.o,\$(INDICES)) \n\n"

# the dispatcher object is declared phony, so make rebuilds it on every run
printf ".PHONY: \$(OUTDIR)/\$(DRIVER) \n\n"

# default goal
printf "all: archive \n\n"

#
# "source" goal: only generate the kernel sources
#
printf "source: \$(SRCFILES) \n\n"

# Pattern rule for a kernel source. The stem ($*) has the form
# smm_<tag>_<m>_<n>_<k>; awk splits it at "_" and passes fields 3, 4 and 5
# (m n k) to ../lib_gen.x, followed by the transpose flavor, the data type,
# the SIMD size and the two ".out" parameter files. The generator's stdout
# becomes the target file. $$ reaches the recipe's shell as a single $.
printf "%%.f90: \n"
printf '\t .././lib_gen.x `echo $* | awk -F_ '\''{ print $$3" "$$4" "$$5 }'\''`'
printf " ${transpose_flavor} ${data_type} ${SIMD_size} ../${in_small_file}.out ../${in_tiny_file}.out > \$@\n\n"

# "compile" goal: generate and compile the kernels, no archive
printf "compile: \$(OBJFILES) \n\n"

#
# pattern rule compiling a kernel source into OUTDIR
#
printf "\$(OUTDIR)/%%.o: %%.f90 \n"
printf "\t ${target_compile} -c \$< -o \$@ \n\n"

# explicit rule for the dispatcher object; it has no prerequisites and
# compiles <stem without directory>.f90 from the current directory
printf "\$(OUTDIR)/\$(DRIVER): \n"
printf "\t ${target_compile} -c \$(notdir \$*).f90 -o \$@ \n\n"

# "archive" goal: the library file itself
printf "archive: ${archive} \n\n"

# the archive collects all kernel objects plus the dispatcher object
printf "${archive}: \$(OBJFILES) \$(OUTDIR)/\$(DRIVER) \n"
printf "\t ar -r \$@ \$^ \n\n"

) > ${make_file}

#
# From here on the script works inside the run directory
#
cd ${lib_dir}

#
# Directory that receives the object files (OUTDIR in the Makefile)
#
mkdir -p ${out_dir}

#
# Run the generated Makefile for the requested goal:
# -B rebuilds unconditionally, -j ${tasks} sets the number of parallel jobs
#
make -B -j ${tasks} -f ../${make_file} ${target}

#
# The goals "source" and "compile" produce no library, so there is nothing
# to test: finish successfully right here
#
case "$target" in
    source|compile)
        exit 0
        ;;
esac

#
# A final test program, checking correctness and comparing performance.
#
# The here-document below starts test_smm_${filetrans}.f90 (the file is
# truncated first). Its delimiter is unquoted, so the shell substitutes
# $strdat, ${decl}, ${lds}, ${gemm}, $ta, $tb and ${filetrans} while writing.
#
# Contents of the generated Fortran:
#   MODULE WTF   generic MYRAND that fills a matrix with random numbers for
#                each of the four element types
#   testit       for one (M,N,K): ten rounds comparing the BLAS gemm with
#                smm_${filetrans} on random data (tolerance 100*EPSILON; on a
#                mismatch the matrices and the "BLAS and smm yield different
#                results" message are printed and the program stops), then a
#                CPU_TIME based timing of Niter calls of each routine
#   PROGRAM      only the header is written here; the calls to testit and
#                the END PROGRAM line are appended by the code that follows
#

cat << EOF > test_smm_${filetrans}.f90
MODULE WTF
  INTERFACE MYRAND
    MODULE PROCEDURE SMYRAND, DMYRAND, CMYRAND, ZMYRAND
  END INTERFACE
CONTAINS
  SUBROUTINE DMYRAND(A)
    REAL(KIND=KIND(0.0D0)), DIMENSION(:,:) :: A
    REAL(KIND=KIND(0.0)), DIMENSION(SIZE(A,1),SIZE(A,2)) :: Aeq
    CALL RANDOM_NUMBER(Aeq)
    A=Aeq
  END SUBROUTINE
  SUBROUTINE SMYRAND(A)
    REAL(KIND=KIND(0.0)), DIMENSION(:,:) :: A
    REAL(KIND=KIND(0.0)), DIMENSION(SIZE(A,1),SIZE(A,2)) :: Aeq
    CALL RANDOM_NUMBER(Aeq)
    A=Aeq
  END SUBROUTINE
  SUBROUTINE CMYRAND(A)
    COMPLEX(KIND=KIND(0.0)), DIMENSION(:,:) :: A
    REAL(KIND=KIND(0.0)), DIMENSION(SIZE(A,1),SIZE(A,2)) :: Aeq,Beq
    CALL RANDOM_NUMBER(Aeq)
    CALL RANDOM_NUMBER(Beq)
    A=CMPLX(Aeq,Beq,KIND=KIND(0.0))
  END SUBROUTINE
  SUBROUTINE ZMYRAND(A)
    COMPLEX(KIND=KIND(0.0D0)), DIMENSION(:,:) :: A
    REAL(KIND=KIND(0.0)), DIMENSION(SIZE(A,1),SIZE(A,2)) :: Aeq,Beq
    CALL RANDOM_NUMBER(Aeq)
    CALL RANDOM_NUMBER(Beq)
    A=CMPLX(Aeq,Beq,KIND=KIND(0.0D0))
  END SUBROUTINE
END MODULE
SUBROUTINE testit(M,N,K)
  USE WTF
  IMPLICIT NONE
  INTEGER :: M,N,K

  $strdat :: C1(M,N), C2(M,N)
  $strdat :: ${decl}
  $strdat, PARAMETER :: one=1
  INTEGER :: i,LDA,LDB

  REAL(KIND=KIND(0.0D0)) :: flops,gflop
  REAL :: t1,t2,t3,t4
  INTEGER :: Niter

  flops=2*REAL(M,KIND=KIND(0.0D0))*N*K
  gflop=1000.0D0*1000.0D0*1000.0D0
  ! assume we would like to do 5 Gflop for testing a subroutine
  Niter=MAX(1,CEILING(MIN(10000000.0D0,5*gflop/flops)))
  ${lds}

  DO i=1,10
     CALL MYRAND(A)
     CALL MYRAND(B)
     CALL MYRAND(C1)
     C2=C1

     CALL ${gemm}("$ta","$tb",M,N,K,one,A,LDA,B,LDB,one,C1,M) 
     CALL smm_${filetrans}(M,N,K,A,B,C2)

     IF (MAXVAL(ABS(C2-C1))>100*EPSILON(REAL(1.0,KIND=KIND(A(1,1))))) THEN
        write(6,*) "Matrix size",M,N,K
        write(6,*) "A=",A
        write(6,*) "B=",B
        write(6,*) "C1=",C1
        write(6,*) "C2=",C2
        write(6,*) "BLAS and smm yield different results : possible compiler bug... do not use the library ;-)"
        STOP
     ENDIF
  ENDDO

  A=0; B=0; C1=0 ; C2=0
 
  CALL CPU_TIME(t1) 
  DO i=1,Niter
     CALL ${gemm}("$ta","$tb",M,N,K,one,A,LDA,B,LDB,one,C1,M) 
  ENDDO
  CALL CPU_TIME(t2) 

  CALL CPU_TIME(t3)
  DO i=1,Niter
     CALL smm_${filetrans}(M,N,K,A,B,C2)
  ENDDO
  CALL CPU_TIME(t4)

  WRITE(6,'(A,I5,I5,I5,A,F6.3,A,F6.3,A,F12.3,A)') "Matrix size ",M,N,K, &
        " smm: ",Niter*flops/(t4-t3)/gflop," Gflops. Linked blas: ",Niter*flops/(t2-t1)/gflop,&
        " Gflops. Performance ratio: ",((t2-t1)/(t4-t3))*100,"%"

END SUBROUTINE 

PROGRAM tester
  IMPLICIT NONE

EOF

# Body of PROGRAM tester: one call of testit per (m,n,k) combination of the
# size list, m varying slowest and k fastest.
for m in ${dims_small}; do
  for n in ${dims_small}; do
    for k in ${dims_small}; do
      printf '   CALL testit(%s,%s,%s)\n' "${m}" "${n}" "${k}" >> test_smm_${filetrans}.f90
    done
  done
done

# Close the program with one large case (the generated comment calls it the
# 'peak' performance check) followed by END PROGRAM.
printf '%s\n' \
  "  ! checking 'peak' performance (and a size likely outside of the library)" \
  "  CALL testit(1000,1000,1000)" \
  "END PROGRAM" >> test_smm_${filetrans}.f90


#
# compile the benchmarking and testing program for the smm library
#
test_prog="test_smm_${filetrans}"
${target_compile} ${test_prog}.f90 -o ${test_prog}.x -L../lib -l${library} ${blas_linking}

#
# Run the test program. Its output is shown and also stored one level up in
# test_smm_<tag>.out, or test_smm_<tag>_<mic>.out when mic is set; in that
# case the program is started through ssh_mic_cmd. Without mic, mic_cmd is
# used as it is (normally empty).
#
test_out_file="${test_prog}${mic+_${mic}}.out"
if [[ -n "${mic+set}" ]]; then
    mic_cmd=ssh_mic_cmd
fi

${mic_cmd} ./${test_prog}.x | tee ../${test_out_file}

#
# We're done... protect the user from bad compilers
#
# From here on a failing command must not abort the script.
set +e

# The test program prints the line searched for below when the library and
# the BLAS disagree. grep -q is used for its exit status only; a missing or
# unreadable output file counts as "no mismatch found".
if grep -q "BLAS and smm yield different results" "../${test_out_file}" 2> /dev/null; then
   # Report the archive that is really deleted. ${archive} is relative to
   # the run directory ("../lib/..."); the leading "../" is stripped so the
   # path reads relative to the top directory. The name is taken from
   # ${archive} itself, so it always matches the file being removed.
   echo "Library is miscompiled ... removing ${archive#../}"
   rm -f -- "${archive}"
else
   pathhere=$(cd .. && pwd)
   echo "Done... check performance looking at ${test_out_file}"
   echo "Final library can be linked as -L${pathhere}/lib -l${library}"
fi

