import math

import numpy as np
import gdal  # legacy pre-osgeo import; rebound by the osgeo import on the next line
from osgeo import gdal
from osgeo import osr
from osgeo import ogr
from osgeo.gdalconst import *
gdal.AllRegister()  # register all drivers so gdal.GetDriverByName() can find them
gdal.UseExceptions()  # report GDAL errors as Python exceptions instead of error return codes
 
'''http://m...content-available-to-author-only...l.com/blog/archive/2012/5/2/understanding-raster-basic-gis-concepts-and-the-python-gdal-library/''' 
 
############# 
# Functions # 
############# 
 
def transform_utm_to_wgs84(easting, northing, zone, is_northern=None):
    """Convert a UTM coordinate to WGS84 geographic coordinates.

    :param easting: UTM easting in meters.
    :param northing: UTM northing in meters.
    :param zone: UTM zone number (1-60).
    :param is_northern: True for a northern-hemisphere zone, False for a
        southern one. UTM northings are positive in the southern hemisphere
        too (it uses a 10 000 000 m false northing), so the hemisphere cannot
        be recovered from the northing and should be passed explicitly.
        When omitted, the legacy rule ``northing > 0`` is applied, which
        treats every positive southern-hemisphere northing as northern.
    :returns: (lon, lat, altitude) tuple from ``TransformPoint``.
    """
    if is_northern is None:
        is_northern = northing > 0

    utm_coordinate_system = osr.SpatialReference()
    # Set geographic coordinate system to handle lat/lon, then the UTM zone.
    utm_coordinate_system.SetWellKnownGeogCS("WGS84")
    utm_coordinate_system.SetUTM(zone, is_northern)

    # Clone ONLY the geographic coordinate system.
    wgs84_coordinate_system = utm_coordinate_system.CloneGeogCS()

    # GDAL >= 3 uses the authority (lat, lon) axis order for WGS84, which would
    # make TransformPoint return (lat, lon, z). Force (lon, lat) when the API
    # exists so the documented return order holds on every GDAL version.
    if hasattr(osr, "OAMS_TRADITIONAL_GIS_ORDER"):
        for spatial_ref in (utm_coordinate_system, wgs84_coordinate_system):
            spatial_ref.SetAxisMappingStrategy(osr.OAMS_TRADITIONAL_GIS_ORDER)

    utm_to_wgs84 = osr.CoordinateTransformation(utm_coordinate_system,
                                                wgs84_coordinate_system)
    return utm_to_wgs84.TransformPoint(easting, northing, 0)
 
class WGS84Transform(object):
    """Helpers for projecting WGS84 lon/lat coordinates into UTM.

    The methods read no instance state; an instance only groups them.
    """

    def get_utm_zone(self, longitude):
        """Return the UTM zone number (1-60) containing *longitude* (degrees).

        Uses the regular 6-degree zone grid (no Norway/Svalbard exceptions)
        and expects longitude in [-180, 180]. Longitude 180.0 is folded into
        zone 60 instead of the non-existent zone 61.
        """
        zone = int(1 + (longitude + 180.0) / 6.0)
        return min(zone, 60)

    def is_lat_northern(self, latitude):
        """Return 1 if *latitude* is in the northern UTM hemisphere, else 0.

        The equator (0.0) counts as northern.
        """
        return 0 if latitude < 0.0 else 1

    def wgs84_to_utm(self, lon, lat):
        """Project a WGS84 lon/lat point into the UTM zone that contains it.

        Zone and hemisphere are chosen from this point alone, so different
        points may end up in different UTM zones.

        :returns: (easting, northing, altitude) tuple from ``TransformPoint``.
        """
        utm_coordinate_system = osr.SpatialReference()
        # Set geographic coordinate system to handle lat/lon, then the UTM zone.
        utm_coordinate_system.SetWellKnownGeogCS("WGS84")
        utm_coordinate_system.SetUTM(self.get_utm_zone(lon), self.is_lat_northern(lat))

        # Clone ONLY the geographic coordinate system.
        wgs84_coordinate_system = utm_coordinate_system.CloneGeogCS()

        # GDAL >= 3 uses the authority (lat, lon) axis order for WGS84, which
        # would make TransformPoint(lon, lat) read its arguments swapped.
        if hasattr(osr, "OAMS_TRADITIONAL_GIS_ORDER"):
            for spatial_ref in (utm_coordinate_system, wgs84_coordinate_system):
                spatial_ref.SetAxisMappingStrategy(osr.OAMS_TRADITIONAL_GIS_ORDER)

        wgs84_to_utm = osr.CoordinateTransformation(wgs84_coordinate_system,
                                                    utm_coordinate_system)
        return wgs84_to_utm.TransformPoint(lon, lat, 0)
 
def get_iterable_extent(*args):
    """Return the minimum and maximum of each argument, flattened in order.

    ``get_iterable_extent(xs, ys)`` returns
    ``[min(xs), max(xs), min(ys), max(ys)]``. Each argument may be a list,
    tuple, array or other iterable. One-shot iterators (e.g. generators) are
    materialized first so both ``min`` and ``max`` see every element.

    :raises ValueError: if any argument is empty (raised by ``min``).
    """
    extent = []
    for values in args:
        # iter(x) is x only for iterators; min() would exhaust those before max().
        if iter(values) is values:
            values = list(values)
        extent.extend((min(values), max(values)))
    return extent
 
 
def get_raster_size(min_x, min_y, max_x, max_y, cell_width, cell_height):
    """Return the (cols, rows) needed to cover the given bounds.

    The grid is anchored at (min_x, max_y), so x falls in column
    floor((x - min_x) / cell_width). Bounds are inclusive: a point lying
    exactly on max_x or min_y still needs a cell, hence the ``+ 1`` after
    flooring. A zero-size extent (a single point) therefore yields 1 x 1.

    Only the magnitude of each cell size is used, so a negative
    ``cell_height`` (as stored in a north-up geotransform) is accepted.
    """
    # Single-argument print() behaves the same under Python 2 and 3.
    print('raster min_x: %s' % (min_x,))
    print('raster max_x: %s' % (max_x,))
    print('raster min_y: %s' % (min_y,))
    print('raster max_y: %s' % (max_y,))
    cols = int(math.floor((max_x - min_x) / abs(cell_width))) + 1
    rows = int(math.floor((max_y - min_y) / abs(cell_height))) + 1
    return cols, rows
 
 
def lonlat_to_pixel(lon, lat, inverse_geo_transform, coordinate_transform=None):
    """Translate a WGS84 lon/lat into zero-based (col, row) pixel indices.

    :param lon: longitude in degrees (WGS84).
    :param lat: latitude in degrees (WGS84).
    :param inverse_geo_transform: inverse of the raster's geotransform
        (map coordinates -> pixel/line), as passed to
        ``gdal.ApplyGeoTransform``.
    :param coordinate_transform: optional ``osr.CoordinateTransformation``
        from WGS84 into the raster's projected CRS, accepting (lon, lat)
        input order. Pass the raster's own transform so every point lands in
        the raster's UTM zone. When omitted, each point is projected into the
        UTM zone that contains it (``WGS84Transform.wgs84_to_utm``). That only
        matches the raster when the point lies in the raster's zone.
    :returns: (col, row) tuple of ints. Results are not clipped: points
        outside the raster give negative or past-the-edge indices, so callers
        must bounds-check before indexing an array with them.
    """
    if coordinate_transform is None:
        map_x, map_y, _ = WGS84Transform().wgs84_to_utm(lon, lat)
    else:
        map_x, map_y, _ = coordinate_transform.TransformPoint(lon, lat, 0)

    pixel_x, pixel_y = gdal.ApplyGeoTransform(inverse_geo_transform, map_x, map_y)
    # The inverse geotransform already yields zero-based fractional pixel
    # coordinates (the raster origin maps to 0.0), so no "-1" is applied.
    # That "-1" is what made the top-left point come out as -1. floor()
    # instead of int() keeps points left of / above the origin negative
    # instead of truncating them into column/row 0.
    return int(math.floor(pixel_x)), int(math.floor(pixel_y))
 
 
def create_raster(lons, lats, values, filename="test.tiff", output_format="GTiff",
                  cell_width_meters=50.0, cell_height_meters=50.0,
                  null_value=-9999.0):
    """Burn point values into a new single-band Float32 UTM raster.

    All points are projected into ONE UTM zone, chosen from the westernmost
    longitude and the northernmost latitude using ``WGS84Transform``'s zone
    and hemisphere rules. NOTE: points far outside that 6-degree zone are
    strongly distorted.

    The raster's top-left corner is the minimum projected easting and the
    maximum projected northing. The raster is just large enough to contain
    every point. Cells that receive no point hold ``null_value``, which is
    also recorded as the band's nodata value. When several points fall in
    the same cell, the last one wins.

    :param lons: longitudes in degrees (WGS84).
    :param lats: latitudes in degrees (WGS84), same length as ``lons``.
    :param values: value written for each point, same length as ``lons``.
    :param filename: output path.
    :param output_format: GDAL driver short name, e.g. "GTiff".
    :param cell_width_meters: cell size along easting, > 0.
    :param cell_height_meters: cell size along northing, > 0.
    :param null_value: fill / nodata value for cells without a point.
    :returns: (col, row) of the last point written.
    :raises ValueError: on empty or mismatched inputs, non-positive cell
        sizes, or an unknown ``output_format``.
    """
    # Materialize: the inputs are iterated several times below.
    lons = list(lons)
    lats = list(lats)
    values = list(values)
    if not lons:
        raise ValueError("create_raster needs at least one point")
    if not len(lons) == len(lats) == len(values):
        raise ValueError("lons, lats and values must have the same length")
    if cell_width_meters <= 0 or cell_height_meters <= 0:
        raise ValueError("cell sizes must be positive")

    driver = gdal.GetDriverByName(output_format)
    if driver is None:
        raise ValueError("unknown GDAL driver: %r" % (output_format,))

    # One projected CRS for the whole raster.
    wgs84_obj = WGS84Transform()
    srs = osr.SpatialReference()
    srs.SetWellKnownGeogCS("WGS84")
    srs.SetUTM(wgs84_obj.get_utm_zone(min(lons)), wgs84_obj.is_lat_northern(max(lats)))
    wgs84_coordinate_system = srs.CloneGeogCS()
    # GDAL >= 3 uses the authority (lat, lon) axis order for WGS84; force
    # (lon, lat) so TransformPoint(lon, lat) below is read correctly.
    if hasattr(osr, "OAMS_TRADITIONAL_GIS_ORDER"):
        for spatial_ref in (srs, wgs84_coordinate_system):
            spatial_ref.SetAxisMappingStrategy(osr.OAMS_TRADITIONAL_GIS_ORDER)
    wgs84_to_utm_transform = osr.CoordinateTransformation(wgs84_coordinate_system, srs)

    # Project every point with the raster's transform, not a per-point zone.
    projected = [wgs84_to_utm_transform.TransformPoint(float(lon), float(lat), 0)[:2]
                 for lon, lat in zip(lons, lats)]

    # Use the bounding box of the projected points. Projecting two lon/lat
    # corners does not bound the points, because UTM grid lines are not
    # aligned with meridians and parallels.
    top_left_x = min(x for x, _ in projected)
    top_left_y = max(y for _, y in projected)

    # North-up geotransform: no rotation, n-s resolution negative so that
    # rows increase southwards.
    geo_transform = [top_left_x, cell_width_meters, 0,
                     top_left_y, 0, -cell_height_meters]

    # Inverting a north-up geotransform is a shift and a scale. Both indices
    # are >= 0 because the origin is the min x / max y of the points.
    pixels = [(int(math.floor((x - top_left_x) / cell_width_meters)),
               int(math.floor((top_left_y - y) / cell_height_meters)))
              for x, y in projected]
    cols = max(col for col, _ in pixels) + 1
    rows = max(row for _, row in pixels) + 1

    # numpy arrays are indexed [row, col]; every cell starts as nodata.
    data = np.full((rows, cols), float(null_value), dtype=np.float32)
    for (col, row), value in zip(pixels, values):
        data[row, col] = value

    dataset = driver.Create(filename, cols, rows, 1, gdal.GDT_Float32)
    dataset.SetGeoTransform(geo_transform)
    dataset.SetProjection(srs.ExportToWkt())
    band = dataset.GetRasterBand(1)  # 1 == band index value
    band.SetNoDataValue(float(null_value))
    band.WriteArray(data, 0, 0)
    band.FlushCache()

    # Drop the references so GDAL closes and finalizes the file.
    band = None
    dataset = None
    return pixels[-1]
 
################# 
# Main Function # 
################# 
 
if __name__ == '__main__':
    # Example points roughly 600 m apart, all inside UTM zone 53N.
    # create_raster puts every point in one UTM zone at 50 m resolution, so
    # widely spread points (e.g. across several 6-degree zones) would need a
    # raster of many gigabytes.
    lat = [45.300, 45.305, 45.310]
    lon = [134.600, 134.604, 134.608]
    val = [3, 6, 2]

    create_raster(lon, lat, val)
