import numpy as np
import gdal
from osgeo import gdal
from osgeo import osr
from osgeo import ogr
from osgeo.gdalconst import *
gdal.AllRegister() # register all drivers
gdal.UseExceptions()

'''http://m...content-available-to-author-only...l.com/blog/archive/2012/5/2/understanding-raster-basic-gis-concepts-and-the-python-gdal-library/'''

#############
# Functions #
#############

def transform_utm_to_wgs84(easting, northing, zone, is_northern=None):
    """Convert a UTM coordinate (WGS84 datum) to geographic lon/lat.

    :param easting: UTM easting in meters
    :param northing: UTM northing in meters
    :param zone: UTM zone number passed to ``SetUTM``
    :param is_northern: hemisphere of the coordinate.  UTM northings are
        positive in BOTH hemispheres (southern-hemisphere northings carry a
        10,000,000 m false northing), so the hemisphere cannot be recovered
        from the northing.  When left as ``None`` the historical guess
        ``northing > 0`` is used, which treats every positive northing as
        northern -- pass ``False`` explicitly for southern-hemisphere input.
    :returns: ``(lon, lat, altitude)`` tuple from ``TransformPoint``
    """
    if is_northern is None:
        # Backward-compatible default; see the docstring for why this is only a guess.
        is_northern = northing > 0

    utm_coordinate_system = osr.SpatialReference()
    # Geographic part (datum) first, then the UTM projection on top of it.
    utm_coordinate_system.SetWellKnownGeogCS("WGS84")
    utm_coordinate_system.SetUTM(zone, int(bool(is_northern)))

    # Clone ONLY the geographic coordinate system (the lon/lat target).
    wgs84_coordinate_system = utm_coordinate_system.CloneGeogCS()

    # GDAL >= 3 honours the EPSG axis order (lat, lon) for geographic CRSs;
    # request the traditional (x=lon, y=lat) order so the result really is
    # (lon, lat).  Older GDAL versions lack the constant and already use it.
    if hasattr(osr, 'OAMS_TRADITIONAL_GIS_ORDER'):
        utm_coordinate_system.SetAxisMappingStrategy(osr.OAMS_TRADITIONAL_GIS_ORDER)
        wgs84_coordinate_system.SetAxisMappingStrategy(osr.OAMS_TRADITIONAL_GIS_ORDER)

    utm_to_wgs84_geo_transform = osr.CoordinateTransformation(
        utm_coordinate_system, wgs84_coordinate_system)

    # returns lon, lat, altitude
    return utm_to_wgs84_geo_transform.TransformPoint(easting, northing, 0)

class WGS84Transform(object):
    """Helpers for projecting WGS84 lon/lat coordinates into UTM."""

    @staticmethod
    def _use_lon_lat_order(spatial_reference):
        """Force x=lon/easting, y=lat/northing axis order on GDAL >= 3 (no-op on older GDAL)."""
        if hasattr(osr, 'OAMS_TRADITIONAL_GIS_ORDER'):
            spatial_reference.SetAxisMappingStrategy(osr.OAMS_TRADITIONAL_GIS_ORDER)

    def get_utm_zone(self, longitude):
        """Return the standard 6-degree UTM zone number (1-60) for a longitude.

        The antimeridian (longitude == 180) belongs to zone 60; without the
        clamp the formula would yield the non-existent zone 61.  The
        Norway/Svalbard zone exceptions are not applied.

        :raises ValueError: if ``longitude`` is outside [-180, 180] (or NaN)
        """
        if not -180.0 <= longitude <= 180.0:
            raise ValueError('longitude out of range [-180, 180]: %r' % (longitude,))
        return min(int((longitude + 180.0) / 6.0) + 1, 60)

    def is_lat_northern(self, latitude):
        """Return 1 if ``latitude`` is treated as northern hemisphere for UTM, else 0.

        The equator (0.0) counts as northern.
        """
        return 0 if latitude < 0.0 else 1

    def wgs84_to_utm(self, lon, lat):
        """Project a WGS84 lon/lat into the UTM zone/hemisphere of that same point.

        :returns: ``(easting, northing, altitude)`` tuple from ``TransformPoint``
        """
        utm_coordinate_system = osr.SpatialReference()
        # Set geographic coordinate system to handle lat/lon
        utm_coordinate_system.SetWellKnownGeogCS("WGS84")
        utm_coordinate_system.SetUTM(self.get_utm_zone(lon), self.is_lat_northern(lat))

        # Clone ONLY the geographic coordinate system (the lon/lat source)
        wgs84_coordinate_system = utm_coordinate_system.CloneGeogCS()

        # Without this, GDAL >= 3 would read (lon, lat) input as (lat, lon).
        self._use_lon_lat_order(utm_coordinate_system)
        self._use_lon_lat_order(wgs84_coordinate_system)

        wgs84_to_utm_geo_transform = osr.CoordinateTransformation(
            wgs84_coordinate_system, utm_coordinate_system)
        # returns easting, northing, altitude
        return wgs84_to_utm_geo_transform.TransformPoint(lon, lat, 0)

def get_iterable_extent(*args):
    """Collect the (min, max) pair of every argument into one flat list.

    ``get_iterable_extent(xs, ys)`` returns
    ``[min(xs), max(xs), min(ys), max(ys)]``.  Each argument is scanned
    twice (once by ``min``, once by ``max``), so it should be a re-iterable
    sequence such as a list or array.
    """
    bounds = []
    for values in args:
        bounds.extend((min(values), max(values)))
    return bounds


def get_raster_size(min_x, min_y, max_x, max_y, cell_width, cell_height):
    """Return the ``(cols, rows)`` needed to cover a bounding box with cells.

    Pixels are indexed as ``floor((x - min_x) / cell_width)`` from the
    left edge and ``floor((max_y - y) / cell_height)`` from the top edge.
    For every point of the closed box -- including points lying exactly on
    the max edges -- to land in a valid cell, each count must be
    ``floor(span / cell) + 1``.  A degenerate box (all points identical)
    therefore yields a 1x1 raster instead of 0 columns/rows.

    :param min_x, min_y, max_x, max_y: box bounds in map units
    :param cell_width: cell width in map units (sign ignored)
    :param cell_height: cell height in map units (sign ignored; geotransforms
        usually carry it negative)
    :returns: ``(cols, rows)`` as ints, both >= 1
    :raises ValueError: if a max bound is smaller than its min bound
    :raises ZeroDivisionError: if a cell dimension is zero
    """
    print('raster min_x: %s' % (min_x,))
    print('raster max_x: %s' % (max_x,))
    print('raster min_y: %s' % (min_y,))
    print('raster max_y: %s' % (max_y,))
    if max_x < min_x or max_y < min_y:
        raise ValueError('max bounds must not be smaller than min bounds')
    # // is floor division for ints and floats alike; the spans are >= 0 here.
    cols = int((max_x - min_x) // abs(cell_width)) + 1
    rows = int((max_y - min_y) // abs(cell_height)) + 1
    return cols, rows


def lonlat_to_pixel(lon, lat, inverse_geo_transform, coordinate_transform=None):
    """Translate a WGS84 lon/lat into zero-based ``(column, row)`` pixel indices.

    :param lon: longitude in WGS84 degrees
    :param lat: latitude in WGS84 degrees
    :param inverse_geo_transform: the inverse of the raster's 6-element GDAL
        geotransform, i.e. the mapping from projected map coordinates to
        pixel/line offsets
    :param coordinate_transform: optional ``osr.CoordinateTransformation``
        from WGS84 (lon, lat axis order) to the raster's projected CRS.
        Recommended: without it the point is projected into its OWN UTM
        zone/hemisphere via ``WGS84Transform.wgs84_to_utm``, which matches
        the raster only when both happen to share that zone.
    :returns: ``(pixel_x, pixel_y)`` -- column and row as ints.  Indices are
        NOT clamped: a point outside the raster yields a value < 0 or >= the
        raster size, and callers must bounds-check before indexing an array
        (negative indices would silently wrap around in NumPy).
    """
    if coordinate_transform is None:
        utm_x, utm_y = WGS84Transform().wgs84_to_utm(lon, lat)[:2]
    else:
        utm_x, utm_y = coordinate_transform.TransformPoint(lon, lat, 0)[:2]

    # The inverse geotransform already yields zero-based offsets (pixel 0
    # spans [0, 1)), so no "-1" adjustment is needed.  With a north-up
    # geotransform (negative n-s resolution) rows grow downward and are
    # non-negative inside the raster, so no abs() is needed either.
    pixel_x, pixel_y = gdal.ApplyGeoTransform(inverse_geo_transform, utm_x, utm_y)

    # floor() rather than int(): int() truncates toward zero and would map
    # points just left of / above the origin (e.g. -0.5) into pixel 0.
    return int(np.floor(pixel_x)), int(np.floor(pixel_y))


def create_raster(lons, lats, values, filename="test.tiff", output_format="GTiff",
                  no_data_value=-9999.0, cell_size_meters=50.0):
    """Burn point values into a new single-band Float32 raster in UTM.

    Every point is projected with ONE WGS84 -> UTM transformation (zone of
    the western-most longitude, hemisphere of the northern-most latitude).
    The same transformation defines the raster extent, so the grid and the
    point locations always agree.  If several points fall into the same
    cell, the last one wins.

    :param lons: longitudes in WGS84 degrees
    :param lats: latitudes in WGS84 degrees (same length as ``lons``)
    :param values: value written at each point (same length as ``lons``)
    :param filename: output path handed to the GDAL driver
    :param output_format: GDAL driver short name, e.g. ``"GTiff"``
    :param no_data_value: fill value for cells that receive no point; it is
        also registered as the band's NoData value
    :param cell_size_meters: square cell edge length in projected meters
    :returns: ``(pixel_x, pixel_y)`` -- zero-based column and row of the
        last point written
    :raises ValueError: on empty input, mismatched input lengths, a
        non-positive cell size, or an unknown driver name
    """
    lons = [float(lon) for lon in lons]
    lats = [float(lat) for lat in lats]
    values = list(values)
    if not lons:
        raise ValueError("create_raster() requires at least one point")
    if not len(lons) == len(lats) == len(values):
        raise ValueError("lons, lats and values must all have the same length")

    number_of_bands = 1
    band_type = gdal.GDT_Float32
    x_rotation = 0
    y_rotation = 0
    cell_width_meters = float(cell_size_meters)
    cell_height_meters = float(cell_size_meters)
    if not cell_width_meters > 0:
        raise ValueError("cell_size_meters must be positive")

    driver = gdal.GetDriverByName(output_format)
    if driver is None:
        raise ValueError("unknown GDAL driver: %r" % (output_format,))

    # retrieve bounds for point data
    min_lon, max_lon = min(lons), max(lons)
    min_lat, max_lat = min(lats), max(lats)
    print('min_lon: %s' % (min_lon,))
    print('max_lon: %s' % (max_lon,))
    print('min_lat: %s' % (min_lat,))
    print('max_lat: %s' % (max_lat,))

    # One projected CRS for the whole raster.  Points far from this zone's
    # central meridian are still projected into it: distortion grows with
    # distance, but every location stays in the same coordinate system.
    wgs84_obj = WGS84Transform()
    srs = osr.SpatialReference()
    srs.SetWellKnownGeogCS("WGS84")
    srs.SetUTM(wgs84_obj.get_utm_zone(min_lon), wgs84_obj.is_lat_northern(max_lat))
    wgs84_coordinate_system = srs.CloneGeogCS()  # clone only the geographic CS
    if hasattr(osr, 'OAMS_TRADITIONAL_GIS_ORDER'):
        # GDAL >= 3 would otherwise expect (lat, lon) input for WGS84.
        srs.SetAxisMappingStrategy(osr.OAMS_TRADITIONAL_GIS_ORDER)
        wgs84_coordinate_system.SetAxisMappingStrategy(osr.OAMS_TRADITIONAL_GIS_ORDER)
    wgs84_to_utm_transform = osr.CoordinateTransformation(wgs84_coordinate_system, srs)

    # Project every point once.  The raster extent is the bounding box of
    # the PROJECTED points: the projected lon/lat corners do not bound the
    # data, because meridians converge in UTM.
    projected = []
    for lon, lat in zip(lons, lats):
        point = wgs84_to_utm_transform.TransformPoint(lon, lat, 0)
        projected.append((point[0], point[1]))
    min_x = min(p[0] for p in projected)
    max_x = max(p[0] for p in projected)
    min_y = min(p[1] for p in projected)
    max_y = max(p[1] for p in projected)

    # floor(span / cell) + 1 so points on the right/bottom edge get a cell.
    cols = int((max_x - min_x) // cell_width_meters) + 1
    rows = int((max_y - min_y) // cell_height_meters) + 1
    print('%s %s' % (cols, rows))
    # NOTE(review): memory use is rows * cols * 4 bytes; widely spread points
    # with small cells can exhaust memory -- consider a larger cell size.

    dataset = driver.Create(filename, cols, rows, number_of_bands, band_type)
    # top left x, w-e pixel resolution, rotation, top left y, rotation,
    # n-s pixel resolution (negative: rows grow southward)
    geo_transform = [min_x, cell_width_meters, x_rotation,
                     max_y, y_rotation, -cell_height_meters]
    dataset.SetGeoTransform(geo_transform)
    dataset.SetProjection(srs.ExportToWkt())

    # NumPy arrays are indexed [row, column]; unset cells hold the NoData value.
    data = np.full((rows, cols), no_data_value, dtype=np.float32)

    pixel_x = pixel_y = None
    for (x, y), value in zip(projected, values):
        # Same arithmetic as the size computation above, so the extreme points
        # map exactly to the last column/row; min() only guards float noise.
        pixel_x = min(int((x - min_x) // cell_width_meters), cols - 1)
        pixel_y = min(int((max_y - y) // cell_height_meters), rows - 1)
        data[pixel_y, pixel_x] = value

    # write the updated data array to file
    band = dataset.GetRasterBand(1)  # 1 == band index value
    band.SetNoDataValue(no_data_value)
    band.WriteArray(data, 0, 0)
    band.FlushCache()

    # drop references so GDAL closes and flushes the file
    band = None
    dataset = None
    # TODO have function return pixel values for interpolation
    return pixel_x, pixel_y

#################
# Main Function #
#################

if __name__ == '__main__':
    # Demo: burn three sample WGS84 points (degrees) into the default output file.
    # NOTE(review): these points span roughly 23 degrees of longitude and 33 of
    # latitude, so with 50 m cells the resulting grid is very large -- consider
    # closer sample points when trying this out.
    sample_lats = [45.3, 56.2, 23.4]
    sample_lons = [134.6, 128.7, 111.9]
    sample_values = [3, 6, 2]

    create_raster(sample_lons, sample_lats, sample_values)