# -*- coding: utf-8 -*-
"""QGIS Unit tests for the postgres raster provider.

Note: to prepare the DB, you need to run the sql files specified in
tests/testdata/provider/testdata_pg.sh

Read tests/README.md about writing/launching tests with PostgreSQL.

Run with ctest -V -R PyQgsPostgresRasterProvider

.. note:: This program is free software; you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation; either version 2 of the License, or
(at your option) any later version.

"""
from builtins import next
__author__ = 'Alessandro Pasotti'
__date__ = '2019-12-20'
__copyright__ = 'Copyright 2019, The QGIS Project'

import qgis  # NOQA
import os
import time

from qgis.core import (
    QgsSettings,
    QgsReadWriteContext,
    QgsRectangle,
    QgsCoordinateReferenceSystem,
    QgsProject,
    QgsRasterLayer,
    QgsPointXY,
    QgsRaster,
    QgsProviderRegistry,
)
from qgis.testing import start_app, unittest
from utilities import unitTestDataPath, compareWkt

# Start the QGIS application before any layer or provider is created; the
# reference is kept at module level (presumably so the application object
# stays alive for the whole test run — confirm against qgis.testing docs).
QGISAPP = start_app()
# Root directory of the unit-test data; fixture SQL dumps are read from its
# provider/postgresraster subfolder.
TEST_DATA_DIR = unitTestDataPath()


class TestPyQgsPostgresRasterProvider(unittest.TestCase):
    """Tests for the ``postgresraster`` data provider.

    The database connection string is taken from the ``QGIS_PGTEST_DB``
    environment variable, falling back to ``service=qgis_test``. Fixture
    tables that do not exist yet are created on demand from the SQL dumps in
    ``<TEST_DATA_DIR>/provider/postgresraster``.
    """

    @classmethod
    def _load_test_table(cls, schemaname, tablename, basename=None):
        """Create a fixture table from its SQL dump if it does not exist yet.

        :param schemaname: schema that is expected to contain the table
        :param tablename: name of the table to look for
        :param basename: base name (without ``.sql``) of the dump file in
            ``provider/postgresraster``; defaults to ``tablename``
        :raises AssertionError: if the table is still missing after the dump
            has been executed
        """

        postgres_conn = cls.dbconn + " sslmode=disable "
        md = QgsProviderRegistry.instance().providerMetadata('postgres')
        conn = md.createConnection(postgres_conn, {})

        if basename is None:
            basename = tablename

        def _table_exists():
            return tablename in [n.tableName() for n in conn.tables(schemaname)]

        if _table_exists():
            return

        sql_path = os.path.join(TEST_DATA_DIR, 'provider', 'postgresraster', basename + '.sql')
        with open(sql_path, 'r', encoding='utf-8') as f:
            sql = f.read()
        conn.executeSql(sql)

        # Explicit raise instead of ``assert`` so the check survives ``python -O``.
        if not _table_exists():
            raise AssertionError('Table {}.{} was not created by {}'.format(schemaname, tablename, sql_path))

    @classmethod
    def setUpClass(cls):
        """Run before all tests: load the tiled fixture and open it as ``cls.rl``."""
        cls.dbconn = os.environ.get('QGIS_PGTEST_DB', 'service=qgis_test')
        # Create test layers
        cls._load_test_table('public', 'raster_tiled_3035')
        cls.rl = QgsRasterLayer(cls.dbconn + ' sslmode=disable key=\'rid\' srid=3035  table="public"."raster_tiled_3035" sql=', 'test', 'postgresraster')
        if not cls.rl.isValid():
            raise AssertionError('Could not load test layer "public"."raster_tiled_3035"')
        cls.source = cls.rl.dataProvider()

    def gdal_block_compare(self, rlayer, band, extent, width, height, value):
        """Compare a block read through the GDAL PostGIS raster driver with *value*.

        The GDAL layer reuses the database, host, port, schema and table of
        *rlayer*'s data source.

        :param rlayer: ``postgresraster`` layer whose connection is reused
        :param band: band number to read
        :param extent: extent of the requested block
        :param width: block width in pixels
        :param height: block height in pixels
        :param value: expected hex-encoded block data
        """

        uri = rlayer.dataProvider().uri()
        gdal_uri = "PG: dbname={dbname} mode=2 host={host} port={port} table={table} schema={schema} sslmode=disable".format(
            dbname=uri.database(),
            host=uri.host(),
            port=uri.port(),
            table=uri.table(),
            schema=uri.schema(),
        )
        gdal_rl = QgsRasterLayer(gdal_uri, "rl", "gdal")
        self.assertTrue(gdal_rl.isValid())
        self.assertEqual(value, gdal_rl.dataProvider().block(band, extent, width, height).data().toHex())

    @classmethod
    def tearDownClass(cls):
        """Run after all tests"""

    def testExtent(self):
        extent = self.rl.extent()
        self.assertEqual(extent, QgsRectangle(4080050, 2430625, 4080200, 2430750))

    def testSize(self):
        self.assertEqual(self.source.xSize(), 6)
        self.assertEqual(self.source.ySize(), 5)

    def testCrs(self):
        self.assertEqual(self.source.crs().authid(), 'EPSG:3035')

    def testGetData(self):
        identify = self.source.identify(QgsPointXY(4080137.9, 2430687.9), QgsRaster.IdentifyFormatValue)
        expected = 192.51044
        self.assertAlmostEqual(identify.results()[1], expected, 4)

    def testBlockTiled(self):

        expected = b'6a610843880b0e431cc2194306342543b7633c43861858436e0a1143bbad194359612743a12b334317be4343dece59432b621b43f0e42843132b3843ac824043e6cf48436e465a435c4d2d430fa63d43f87a4843b5494a4349454e4374f35b43906e41433ab54c43b056504358575243b1ec574322615f43'
        block = self.source.block(1, self.rl.extent(), 6, 5)
        actual = block.data().toHex()
        self.assertEqual(len(actual), len(expected))
        self.assertEqual(actual, expected)

    def testNoConstraintRaster(self):
        """Read unconstrained raster layer"""

        self._load_test_table('public', 'raster_3035_no_constraints')
        rl = QgsRasterLayer(self.dbconn + ' sslmode=disable key=\'pk\' srid=3035  table="public"."raster_3035_no_constraints" sql=', 'test', 'postgresraster')
        self.assertTrue(rl.isValid())

    def testPkGuessing(self):
        """Read raster layer with no pkey in uri"""

        self._load_test_table('public', 'raster_tiled_3035')
        rl = QgsRasterLayer(self.dbconn + ' sslmode=disable srid=3035  table="public"."raster_tiled_3035" sql=', 'test', 'postgresraster')
        self.assertTrue(rl.isValid())

    def testWhereCondition(self):
        """Read raster layer with where condition"""

        self._load_test_table('public', 'raster_3035_tiled_no_overviews')
        rl_nowhere = QgsRasterLayer(self.dbconn + ' sslmode=disable srid=3035  table="public"."raster_3035_tiled_no_overviews"' +
                                    ' sql=', 'test', 'postgresraster')
        self.assertTrue(rl_nowhere.isValid())

        rl = QgsRasterLayer(self.dbconn + ' sslmode=disable srid=3035  table="public"."raster_3035_tiled_no_overviews"' +
                            ' sql="category" = \'cat2\'', 'test', 'postgresraster')
        self.assertTrue(rl.isValid())

        self.assertFalse(rl.extent().isEmpty())
        self.assertNotEqual(rl_nowhere.extent(), rl.extent())

        # This point lies outside the filtered ('cat2') tiles but inside the unfiltered layer.
        point = QgsPointXY(4080137.9, 2430687.9)
        self.assertIsNone(rl.dataProvider().identify(point, QgsRaster.IdentifyFormatValue).results()[1])
        self.assertIsNotNone(rl_nowhere.dataProvider().identify(point, QgsRaster.IdentifyFormatValue).results()[1])

        self.assertAlmostEqual(rl.dataProvider().identify(rl.extent().center(), QgsRaster.IdentifyFormatValue).results()[1], 223.38, 2)

    def testNoPk(self):
        """Read raster with no PK"""

        self._load_test_table('public', 'raster_3035_tiled_no_pk')
        rl = QgsRasterLayer(self.dbconn + ' sslmode=disable srid=3035  table="public"."raster_3035_tiled_no_pk"' +
                            ' sql=', 'test', 'postgresraster')
        self.assertTrue(rl.isValid())

    def testCompositeKey(self):
        """Read raster with composite pks"""

        self._load_test_table('public', 'raster_3035_tiled_composite_pk')
        rl = QgsRasterLayer(self.dbconn + ' sslmode=disable srid=3035  table="public"."raster_3035_tiled_composite_pk"' +
                            ' sql=', 'test', 'postgresraster')
        self.assertTrue(rl.isValid())
        data = rl.dataProvider().block(1, rl.extent(), 3, 3)
        self.assertEqual(int(data.value(0, 0)), 142)

    @unittest.skip('Performance test is disabled in Travis environment')
    def testSpeed(self):
        """Compare speed with GDAL provider, this test was used during development"""

        conn = "user={user} host=localhost port=5432 password={password} dbname={speed_db} ".format(
            user=os.environ.get('USER'),
            password=os.environ.get('USER'),
            speed_db='qgis_test'
        )

        table = 'basic_map_tiled'
        schema = 'public'

        def _speed_check(schema, table, width, height):

            print('-' * 80)
            print("Testing: {schema}.{table}".format(table=table, schema=schema))
            print('-' * 80)

            # GDAL
            start = time.time()
            rl = QgsRasterLayer("PG: " + conn + "table={table} mode=2 schema={schema}".format(table=table, schema=schema), 'gdal_layer', 'gdal')
            self.assertTrue(rl.isValid())
            # Make it smaller than full extent
            extent = rl.extent().buffered(-rl.extent().width() * 0.2)
            checkpoint_1 = time.time()
            print("Tiled GDAL start time: {:.6f}".format(checkpoint_1 - start))
            rl.dataProvider().block(1, extent, width, height)
            checkpoint_2 = time.time()
            print("Tiled GDAL first block time: {:.6f}".format(checkpoint_2 - checkpoint_1))
            # Second read mirrors the PG measurement below (e.g. effect of caching).
            rl.dataProvider().block(1, extent, width, height)
            checkpoint_3 = time.time()
            print("Tiled GDAL second block time: {:.6f}".format(checkpoint_3 - checkpoint_2))
            print("Total GDAL time: {:.6f}".format(checkpoint_3 - start))
            print('-' * 80)

            # PG native
            start = time.time()
            rl = QgsRasterLayer(conn + "table={table} schema={schema}".format(table=table, schema=schema), 'pg_layer', 'postgresraster')
            self.assertTrue(rl.isValid())
            extent = rl.extent().buffered(-rl.extent().width() * 0.2)
            checkpoint_1 = time.time()
            print("Tiled PG start time: {:.6f}".format(checkpoint_1 - start))
            rl.dataProvider().block(1, extent, width, height)
            checkpoint_2 = time.time()
            print("Tiled PG first block time: {:.6f}".format(checkpoint_2 - checkpoint_1))
            rl.dataProvider().block(1, extent, width, height)
            checkpoint_3 = time.time()
            print("Tiled PG second block time: {:.6f}".format(checkpoint_3 - checkpoint_2))
            print("Total PG time: {:.6f}".format(checkpoint_3 - start))
            print('-' * 80)

        _speed_check(schema, table, 1000, 1000)

    def testOtherSchema(self):
        """Test that a layer in a different schema than public can be loaded
        See: GH #34823"""

        self._load_test_table('idro', 'cosmo_i5_snow', 'bug_34823_pg_raster')

        rl = QgsRasterLayer(self.dbconn + " sslmode=disable table={table} schema={schema}".format(table='cosmo_i5_snow', schema='idro'), 'pg_layer', 'postgresraster')
        self.assertTrue(rl.isValid())
        self.assertTrue(compareWkt(rl.extent().asWktPolygon(), 'POLYGON((-64.79286766849691048 -77.26689086732433509, -62.18292922825105506 -77.26689086732433509, -62.18292922825105506 -74.83694818157819384, -64.79286766849691048 -74.83694818157819384, -64.79286766849691048 -77.26689086732433509))'))

    def testUntiledMultipleRows(self):
        """Test multiple rasters (one per row)"""

        self._load_test_table('public', 'raster_3035_untiled_multiple_rows')

        def _row_values(pk):
            """Open the raster stored in row *pk* and return its 2x2 band-1 block, row by row, as ints."""
            rl = QgsRasterLayer(self.dbconn + " sslmode=disable table={table} schema={schema} sql=\"pk\" = {pk}".format(table='raster_3035_untiled_multiple_rows', schema='public', pk=pk), 'pg_layer', 'postgresraster')
            self.assertTrue(rl.isValid())
            block = rl.dataProvider().block(1, rl.extent(), 2, 2)
            return [int(block.value(i, j)) for i in range(2) for j in range(2)]

        self.assertEqual(_row_values(1), [136, 142, 145, 153])
        self.assertEqual(_row_values(2), [136, 142, 161, 169])


# Allow running this module directly (outside ctest) with the unittest runner.
if __name__ == '__main__':
    unittest.main()
