/*
 *  Copyright (c) 2008-2009 Cyrille Berger <cberger@cberger.net>
 *
 * This library is free software; you can redistribute it and/or
 * modify it under the terms of the GNU Lesser General Public
 * License as published by the Free Software Foundation;
 * either version 2, or (at your option) any later version of the License.
 *
 * This library is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 * Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public License
 * along with this library; see the file COPYING.  If not, write to
 * the Free Software Foundation, Inc., 51 Franklin Street, Fifth Floor,
 * Boston, MA 02110-1301, USA.
 */

#include "FunctionCaller_p.h"

#include <llvm/ExecutionEngine/GenericValue.h>
#include <llvm/ExecutionEngine/ExecutionEngine.h>
#include <llvm/Function.h>

#include <llvm/DerivedTypes.h> // <- I don't understand why I need to include that file to be able to display llvm::Type on the standard output

#include "VirtualMachine_p.h"

#include "Debug.h"
#include "GTLCore/Function.h"
#include "GTLCore/Value.h"
#include <GTLCore/Type.h>
#include "GTLCore/FunctionCaller_p.h"
#include "LLVMBackend/GenerationContext_p.h"
#include "ModuleData_p.h"
#include "LLVMBackend/CodeGenerator_p.h"
#include <llvm/Instructions.h>
#include "Type_p.h"
#include <GTLCore/Parameter.h>
#include "PrimitiveTypesTraits_p.h"

using namespace LLVMBackend;
using namespace GTLCore;

/**
 * Private data of FunctionCaller (pimpl).
 *
 * The default constructor zero-initializes every pointer and size so that a
 * Private instance is always in a well-defined state, even for members the
 * FunctionCaller constructor does not assign.
 */
struct FunctionCaller::Private {
  Private() : function(0), callFunction(0), gtlFunction(0), functionPtr(0), func(0),
              sizeOfArguments(0), sizeOfReturn(0), returnConverter(0)
  {
  }
  llvm::Function* function;            ///< LLVM function being wrapped.
  llvm::Function* callFunction;        ///< LLVM stub used to invoke 'function' (null until set).
  const Function* gtlFunction;         ///< GTLCore description of the wrapped function.
  void* functionPtr;                   ///< Not used by the code in this file; kept null.
  void (*func)( void*, const void*);   ///< JIT-compiled stub: func(result buffer, packed argument buffer).
  gtl_int32 sizeOfArguments;           ///< Size in bytes of the packed argument buffer.
  int sizeOfReturn;                    ///< Size in bytes of the return buffer (0 for void).
  std::vector<ValueToArray*> valueToArrays; ///< One packer per parameter, allocated with new by the constructor.
  ArrayToValue* returnConverter;       ///< Unpacks the return buffer; null for void functions.
};

/**
 * Build a caller for @p llvmFunction, described on the GTLCore side by @p function.
 *
 * An internal stub with the LLVM signature `void (i8* result, i8* params)` is
 * emitted into @p _data's linked module. The stub reads each argument from the
 * packed @p params buffer, calls @p llvmFunction and, for non-void functions,
 * stores the result into @p result. The stub is then JIT-compiled and kept in
 * d->func. In parallel, d->valueToArrays and d->returnConverter are set up so
 * that call() packs/unpacks GTLCore::Value objects with exactly the same byte
 * offsets the stub uses.
 *
 * Supported parameter types: INTEGER32, FLOAT16, FLOAT32.
 * Supported return types: INTEGER32, FLOAT16, FLOAT32, VOID.
 * Any other type triggers GTL_ABORT("Unimplemented").
 */
FunctionCaller::FunctionCaller(llvm::Function* llvmFunction, const Function* function, ModuleData* _data) : d(new Private)
{
  d->function = llvmFunction;
  d->gtlFunction = function;
  d->functionPtr = 0;
  d->sizeOfArguments = 0;
  d->returnConverter = 0;
  
  // Create a function to call the function
  // void callFunction(void* result, void* params)
  // {
  //   Type1 arg1 = *(Type1*)(params); params += sizeof(Type1);
  //   Type2 arg2 = *(Type2*)(params); params += sizeof(Type2);
  //   ...
  //   TypeN argN = *(TypeN*)(params); params += sizeof(TypeN);
  //   *(TypeRet*)(result) = function(arg1, arg2, ..., argN );
  // }
  llvm::LLVMContext& context = llvmFunction->getContext();

  // Stub signature: void (i8* result, i8* params)
  std::vector<llvm::Type*> params;
  llvm::Type* pointerType = llvm::PointerType::get(llvm::Type::getInt8Ty(context), 0);
  params.push_back(pointerType);
  params.push_back(pointerType);
  llvm::FunctionType *STy = llvm::FunctionType::get(llvm::Type::getVoidTy(context), params, false);
  llvm::Function *Stub = llvm::Function::Create(STy, llvm::Function::InternalLinkage, "", _data->llvmLinkedModule());
  d->callFunction = Stub;

  // Init the code generation
  LLVMBackend::CodeGenerator _codeGenerator(_data);
  LLVMBackend::GenerationContext gc(&_codeGenerator, &context, Stub, 0, _data, _data->llvmLinkedModule());
  llvm::BasicBlock *StubBB = gc.createBasicBlock();
  
  // Read the arguments
  llvm::Function::arg_iterator arg_it = Stub->arg_begin();
  //   void* result = first arg;
  llvm::Value* arg_result = arg_it;
  //   void* params = second arg;
  ++arg_it;
  llvm::Value* arg_params = arg_it;
  std::vector<llvm::Value*> Args;
  
  // TypeI argI = *(TypeI*)(params); params += sizeof(TypeI);
  // The same running offset (d->sizeOfArguments) is used both for the GEP in
  // the stub and for the ValueToArray packer, so both sides agree on layout.
  for(std::size_t i = 0; i < function->parameters().size(); ++i)
  {
    const Type* type = function->parameters()[i].type();
    llvm::Value* ptr = CodeGenerator::convertPointerTo(StubBB, llvm::GetElementPtrInst::Create( arg_params, CodeGenerator::integerToConstant(context, d->sizeOfArguments), "", StubBB ), type->d->type(context));
    Args.push_back(new llvm::LoadInst(ptr, "", StubBB));
    GTL_ASSERT(type->bitsSize() % 8 == 0);
    int typeSize = type->bitsSize() / 8;
    switch(type->dataType())
    {
    case Type::INTEGER32:
      // gtl_int32 for consistency with the INTEGER32 return converter below.
      d->valueToArrays.push_back( new ValueToArrayImpl<gtl_int32>(d->sizeOfArguments) );
      break;
    case Type::FLOAT16:
      // NOTE(review): FLOAT16 is marshalled through a C float; confirm that
      // bitsSize()/8 for FLOAT16 matches sizeof(float), otherwise packed
      // arguments would overlap.
    case Type::FLOAT32:
      d->valueToArrays.push_back( new ValueToArrayImpl<float>(d->sizeOfArguments));
      break;
    default:
      GTL_ABORT("Unimplemented");
    }
    d->sizeOfArguments += typeSize;
  }
  
  // function(arg1, arg2, ..., argN );
  llvm::CallInst *TheCall = llvm::CallInst::Create(llvmFunction, Args, "", StubBB);
  TheCall->setCallingConv(llvmFunction->getCallingConv());
  TheCall->setTailCall();

  // *(TypeRet*)(result) = ...;
  if(function->returnType() != Type::Void)
  {
    llvm::Type* retType = function->returnType()->d->type(context);
    llvm::Value* ptr = CodeGenerator::convertPointerTo(StubBB, arg_result, retType);
    new llvm::StoreInst(TheCall,ptr, StubBB);
  }
  llvm::ReturnInst::Create(context, StubBB);

  // Dump the stub only once it is complete (result store and return emitted).
  GTL_DEBUG( *Stub );
  
  // Generate the function
  d->func = ( void(*)(void*, const void*)) GTLCore::VirtualMachine::instance()->getPointerToFunction( Stub );
  GTL_ASSERT(d->func); 
  
  // Create the return converter
  switch(function->returnType()->dataType())
  {
    case Type::INTEGER32:
      d->returnConverter = new ArrayToValueImpl<gtl_int32>();
      break;
    case Type::FLOAT16:
    case Type::FLOAT32:
      d->returnConverter = new ArrayToValueImpl<float>();
      break;
    case Type::VOID:
      d->returnConverter = 0;
      break;
    default:
      GTL_ABORT("Unimplemented");
  }
  d->sizeOfReturn = function->returnType()->bitsSize() / 8;
}

/**
 * Destroy the caller, releasing the per-parameter packers and the return
 * converter that the constructor allocated with new, then the private data.
 */
FunctionCaller::~FunctionCaller()
{
  // NOTE(review): deleting through ValueToArray* / ArrayToValue* assumes these
  // base classes declare virtual destructors; confirm in FunctionCaller_p.h.
  for(std::size_t i = 0; i < d->valueToArrays.size(); ++i)
  {
    delete d->valueToArrays[i];
  }
  d->valueToArrays.clear();
  delete d->returnConverter; // null for void functions; delete of null is a no-op
  delete d;
}

/**
 * Invoke the wrapped function.
 *
 * Each argument is packed into a temporary byte buffer by the packer the
 * constructor prepared for that parameter. The JIT-compiled stub is then
 * called, and the result (if any) is converted back into a GTLCore::Value.
 *
 * @param arguments one Value per parameter, in declaration order
 *                  (asserted to match the parameter count)
 * @return the function's result, or a default-constructed Value when the
 *         function returns void
 */
GTLCore::Value FunctionCaller::call(const std::vector<GTLCore::Value>& arguments )
{
  GTL_ASSERT(arguments.size() == d->valueToArrays.size());

  // Scratch buffers are owned by std::vector so they are released on every
  // exit path; std::allocator obtains them from operator new, which returns
  // storage suitably aligned for the int/float values stored in them.
  // TODO make a thread cache of those
  std::vector<char> argsBuffer(static_cast<std::size_t>(d->sizeOfArguments));
  std::vector<char> returnBuffer(static_cast<std::size_t>(d->sizeOfReturn));
  void* args_ptr = argsBuffer.empty() ? 0 : static_cast<void*>(&argsBuffer[0]);
  void* return_ptr = returnBuffer.empty() ? 0 : static_cast<void*>(&returnBuffer[0]);
  
  for(std::size_t i = 0; i < arguments.size(); ++i)
  {
    d->valueToArrays[i]->store(args_ptr, arguments[i]);
  }
  
  (*d->func)(return_ptr, args_ptr);
  
  if(d->returnConverter)
  {
    return d->returnConverter->load(return_ptr);
  } else {
    return Value();
  }
}

