import  numpy as  np
import  gdal
from  osgeo import  gdal
from  osgeo import  osr
from  osgeo import  ogr
from  osgeo.gdalconst  import  *
gdal.AllRegister ( )  # register all drivers 
gdal.UseExceptions ( ) 
 
'''http://m...content-available-to-author-only...l.com/blog/archive/2012/5/2/understanding-raster-basic-gis-concepts-and-the-python-gdal-library/''' 
 
############# 
# Functions # 
############# 
 
def transform_utm_to_wgs84(easting, northing, zone, is_northern=None):
    """Convert a UTM coordinate to WGS84 longitude/latitude.

    Args:
        easting: UTM easting in metres.
        northing: UTM northing in metres.
        zone: UTM zone number.
        is_northern: True for the northern hemisphere, False for the southern.
            When omitted, falls back to the legacy guess ``northing > 0``.
            Standard UTM northings are never negative (the southern
            hemisphere uses a 10,000,000 m false northing), so that guess
            practically always picks the northern hemisphere -- pass this
            argument explicitly for southern-hemisphere data.

    Returns:
        ``(lon, lat, altitude)`` as returned by ``TransformPoint``.
    """
    if is_northern is None:
        is_northern = northing > 0

    utm_srs = osr.SpatialReference()
    # WGS84 datum for the UTM projection
    utm_srs.SetWellKnownGeogCS("WGS84")
    utm_srs.SetUTM(zone, is_northern)

    # Clone ONLY the geographic (lon/lat) part of the UTM definition
    wgs84_srs = utm_srs.CloneGeogCS()
    if hasattr(osr, 'OAMS_TRADITIONAL_GIS_ORDER'):
        # GDAL >= 3 uses the authority's lat/lon axis order for geographic
        # CRSs by default; force lon/lat so the result really is (lon, lat).
        wgs84_srs.SetAxisMappingStrategy(osr.OAMS_TRADITIONAL_GIS_ORDER)
        utm_srs.SetAxisMappingStrategy(osr.OAMS_TRADITIONAL_GIS_ORDER)

    utm_to_wgs84 = osr.CoordinateTransformation(utm_srs, wgs84_srs)
    return utm_to_wgs84.TransformPoint(easting, northing, 0)
 
class WGS84Transform(object):
    """Helpers for converting WGS84 longitude/latitude into UTM coordinates.

    Zones follow the regular 6-degree grid; the Norway/Svalbard zone
    exceptions are not handled.
    """

    def get_utm_zone(self, longitude):
        """Return the UTM zone number (1-60) that contains ``longitude``.

        The longitude is wrapped into [-180, 180) first, so 180 maps to zone 1
        (same meridian as -180) and out-of-range values such as 190 still give
        a valid zone instead of 61+ or 0.
        """
        zone = int(((longitude + 180.0) % 360.0) // 6.0) + 1
        # Guard against float rounding of the modulo producing exactly 360.0.
        return min(zone, 60)

    def is_lat_northern(self, latitude):
        """Return 1 if ``latitude`` lies in the northern UTM hemisphere, else 0.

        The equator (0.0) counts as northern.
        """
        return 0 if latitude < 0.0 else 1

    def wgs84_to_utm(self, lon, lat):
        """Project a WGS84 lon/lat into UTM.

        The zone and hemisphere are chosen from THIS point, so points in
        different zones end up in different coordinate systems; use a single
        shared ``osr.CoordinateTransformation`` when points must be compared.

        Returns:
            ``(easting, northing, altitude)`` as returned by ``TransformPoint``.
        """
        utm_srs = osr.SpatialReference()
        # WGS84 datum for the UTM projection
        utm_srs.SetWellKnownGeogCS("WGS84")
        utm_srs.SetUTM(self.get_utm_zone(lon), self.is_lat_northern(lat))

        # Clone ONLY the geographic (lon/lat) part of the UTM definition
        wgs84_srs = utm_srs.CloneGeogCS()
        if hasattr(osr, 'OAMS_TRADITIONAL_GIS_ORDER'):
            # GDAL >= 3 uses lat/lon axis order for geographic CRSs by
            # default; force lon/lat so TransformPoint(lon, lat) is correct.
            wgs84_srs.SetAxisMappingStrategy(osr.OAMS_TRADITIONAL_GIS_ORDER)
            utm_srs.SetAxisMappingStrategy(osr.OAMS_TRADITIONAL_GIS_ORDER)

        wgs84_to_utm = osr.CoordinateTransformation(wgs84_srs, utm_srs)
        return wgs84_to_utm.TransformPoint(lon, lat, 0)
 
def get_iterable_extent(*args):
    """Return ``[min, max]`` of every argument, flattened into one list.

    ``get_iterable_extent(xs, ys)`` -> ``[min(xs), max(xs), min(ys), max(ys)]``.
    One-shot iterators (e.g. generators) are materialised first so that both
    ``min`` and ``max`` see every element.

    Raises:
        ValueError: if any argument is empty (propagated from ``min``).
    """
    extent = []
    for values in args:
        # An iterator returns itself from iter(); lists/arrays do not.
        if iter(values) is values:
            values = list(values)
        extent.extend((min(values), max(values)))
    return extent
 
 
def get_raster_size(min_x, min_y, max_x, max_y, cell_width, cell_height):
    """Return ``(cols, rows)`` needed to cover the given bounds with cells.

    The grid is assumed to be anchored at the (min_x, max_y) corner.  One
    extra cell is added in each direction so that every point in the bounds,
    including one lying exactly on max_x / min_y, maps to a valid cell index,
    and a degenerate extent (a single point) gives a 1 x 1 raster instead of
    0 x 0.

    Cell sizes may be given with either sign (the geotransform's north-south
    resolution is conventionally negative); only their magnitudes are used.

    Raises:
        ZeroDivisionError: if ``cell_width`` or ``cell_height`` is 0.
    """
    cols = int(abs(max_x - min_x) / abs(cell_width)) + 1
    rows = int(abs(max_y - min_y) / abs(cell_height)) + 1
    return cols, rows
 
 
def lonlat_to_pixel(lon, lat, inverse_geo_transform, wgs84_to_utm_transform=None):
    """Translate a WGS84 lon/lat into zero-based (column, row) raster indices.

    Args:
        lon: longitude in decimal degrees.
        lat: latitude in decimal degrees.
        inverse_geo_transform: the six inverse geotransform coefficients of
            the target raster.
        wgs84_to_utm_transform: optional ``osr.CoordinateTransformation``
            from WGS84 to the raster's projected CRS.  Supply it whenever
            possible: without it the point is projected into the UTM zone of
            the point itself (legacy behaviour), which only matches the
            raster when the point and the raster share a zone.

    Returns:
        ``(pixel_x, pixel_y)`` as ints, i.e. (column, row).  The continuous
        pixel coordinates are floored, so the raster origin cell is (0, 0);
        points west of / north of the origin give negative indices and points
        beyond the far edges give indices >= the raster size, so callers must
        bounds-check before indexing.
    """
    import math

    if wgs84_to_utm_transform is None:
        utm_x, utm_y = WGS84Transform().wgs84_to_utm(lon, lat)[:2]
    else:
        utm_x, utm_y = wgs84_to_utm_transform.TransformPoint(lon, lat, 0)[:2]

    # Inverse geotransform maps map coordinates to continuous pixel/line
    # coordinates where [0, 1) is the first cell -- no "-1" adjustment needed.
    pixel_x, pixel_y = gdal.ApplyGeoTransform(inverse_geo_transform, utm_x, utm_y)
    return int(math.floor(pixel_x)), int(math.floor(pixel_y))
 
 
def create_raster(lons, lats, values, filename="test.tiff", output_format="GTiff",
                  cell_size_meters=50.0, nodata_value=-9999.0):
    """Write point values (WGS84 lon/lat) into a single-band Float32 UTM raster.

    All points are projected into ONE UTM coordinate system -- the zone of the
    centre longitude of the data and the hemisphere of the centre latitude --
    and the raster is sized so that every projected point lands in a cell.
    Cells that receive no point keep ``nodata_value``; if several points fall
    into the same cell, the last one wins.

    Args:
        lons: longitudes in decimal degrees (WGS84).
        lats: latitudes in decimal degrees (WGS84), paired with ``lons``.
        values: value to store for each point, paired with the points via
            ``zip`` (points without a value only contribute to the extent).
        filename: output path handed to the GDAL driver.
        output_format: GDAL driver short name, e.g. "GTiff".
        cell_size_meters: edge length of the square cells, in UTM metres.
        nodata_value: fill value for empty cells, also registered as the
            band's NoData value.  Pass None to fill with 0 and register no
            NoData value.

    Returns:
        ``(pixel_x, pixel_y)`` -- zero-based column and row of the last point
        written, or ``(None, None)`` if ``values`` is empty.

    Raises:
        ValueError: if no lon/lat pairs are given, ``cell_size_meters`` is not
            positive, or ``output_format`` is not a known GDAL driver.
    """
    points_lonlat = [(float(lon), float(lat)) for lon, lat in zip(lons, lats)]
    if not points_lonlat:
        raise ValueError("create_raster needs at least one lon/lat pair")
    if not cell_size_meters > 0:
        raise ValueError("cell_size_meters must be positive")

    driver = gdal.GetDriverByName(output_format)
    if driver is None:
        raise ValueError("unknown GDAL driver: %s" % output_format)

    # One projected CRS for the whole dataset.  Projecting each point into its
    # own UTM zone would mix incompatible coordinate systems as soon as the
    # data crosses a zone boundary.
    all_lons = [lon for lon, _ in points_lonlat]
    all_lats = [lat for _, lat in points_lonlat]
    center_lon = (min(all_lons) + max(all_lons)) / 2.0
    center_lat = (min(all_lats) + max(all_lats)) / 2.0

    wgs84_obj = WGS84Transform()
    utm_srs = osr.SpatialReference()
    utm_srs.SetWellKnownGeogCS("WGS84")
    utm_srs.SetUTM(wgs84_obj.get_utm_zone(center_lon),
                   wgs84_obj.is_lat_northern(center_lat))
    wgs84_srs = utm_srs.CloneGeogCS()  # geographic (lon/lat) part only
    if hasattr(osr, 'OAMS_TRADITIONAL_GIS_ORDER'):
        # GDAL >= 3 uses lat/lon axis order for geographic CRSs by default;
        # force lon/lat so TransformPoint(lon, lat, ...) is unambiguous.
        wgs84_srs.SetAxisMappingStrategy(osr.OAMS_TRADITIONAL_GIS_ORDER)
        utm_srs.SetAxisMappingStrategy(osr.OAMS_TRADITIONAL_GIS_ORDER)
    wgs84_to_utm = osr.CoordinateTransformation(wgs84_srs, utm_srs)

    # Extent from the projected points themselves: UTM grid north is not
    # parallel to the meridians, so the (min_lon, max_lat) corner is not in
    # general the top-left of the projected data.
    utm_points = [wgs84_to_utm.TransformPoint(lon, lat, 0)[:2]
                  for lon, lat in points_lonlat]
    min_x = min(x for x, _ in utm_points)
    max_x = max(x for x, _ in utm_points)
    min_y = min(y for _, y in utm_points)
    max_y = max(y for _, y in utm_points)

    # +1 so points on the far (east / south) edge still fall inside the grid
    # and a single point yields a 1 x 1 raster.
    cols = int((max_x - min_x) / cell_size_meters) + 1
    rows = int((max_y - min_y) / cell_size_meters) + 1

    dataset = driver.Create(filename, cols, rows, 1, gdal.GDT_Float32)
    # top-left x, w-e pixel size, rotation, top-left y, rotation,
    # n-s pixel size (negative: row index grows southwards)
    dataset.SetGeoTransform([min_x, cell_size_meters, 0, max_y, 0, -cell_size_meters])
    dataset.SetProjection(utm_srs.ExportToWkt())

    fill_value = 0.0 if nodata_value is None else nodata_value
    # Raster arrays are indexed (row, col), i.e. data[y, x].
    data = np.full((rows, cols), fill_value, dtype=np.float32)

    pixel_x = pixel_y = None
    for (x, y), value in zip(utm_points, values):
        # min_x / max_y bound every projected point, so both offsets are >= 0
        # and int() floors them to an index inside [0, cols) x [0, rows).
        pixel_x = int((x - min_x) / cell_size_meters)
        pixel_y = int((max_y - y) / cell_size_meters)
        data[pixel_y, pixel_x] = value

    band = dataset.GetRasterBand(1)  # band indices start at 1
    band.WriteArray(data, 0, 0)
    if nodata_value is not None:
        band.SetNoDataValue(nodata_value)
    band.FlushCache()

    # Dropping the last references "closes" the dataset and flushes it.
    band = None
    dataset = None
    return pixel_x, pixel_y
 
################# 
# Main Function # 
################# 
 
if __name__ == '__main__':
    # Example run: (longitude, latitude, value) samples in WGS84 degrees.
    samples = [
        (134.6, 45.3, 3),
        (128.7, 56.2, 6),
        (111.9, 23.4, 2),
    ]
    lon, lat, val = (list(column) for column in zip(*samples))

    create_raster(lon, lat, val)
import numpy as np
import gdal
from osgeo import gdal
from osgeo import osr
from osgeo import ogr
from osgeo.gdalconst import *
gdal.AllRegister() # register all drivers
gdal.UseExceptions()

'''http://m...content-available-to-author-only...l.com/blog/archive/2012/5/2/understanding-raster-basic-gis-concepts-and-the-python-gdal-library/'''

#############
# Functions #
#############

def transform_utm_to_wgs84(easting, northing, zone, is_northern=None):
    """Convert a UTM coordinate to WGS84 longitude/latitude.

    Args:
        easting: UTM easting in metres.
        northing: UTM northing in metres.
        zone: UTM zone number.
        is_northern: True for the northern hemisphere, False for the southern.
            When omitted, falls back to the legacy guess ``northing > 0``.
            Standard UTM northings are never negative (the southern
            hemisphere uses a 10,000,000 m false northing), so that guess
            practically always picks the northern hemisphere -- pass this
            argument explicitly for southern-hemisphere data.

    Returns:
        ``(lon, lat, altitude)`` as returned by ``TransformPoint``.
    """
    if is_northern is None:
        is_northern = northing > 0

    utm_srs = osr.SpatialReference()
    # WGS84 datum for the UTM projection
    utm_srs.SetWellKnownGeogCS("WGS84")
    utm_srs.SetUTM(zone, is_northern)

    # Clone ONLY the geographic (lon/lat) part of the UTM definition
    wgs84_srs = utm_srs.CloneGeogCS()
    if hasattr(osr, 'OAMS_TRADITIONAL_GIS_ORDER'):
        # GDAL >= 3 uses the authority's lat/lon axis order for geographic
        # CRSs by default; force lon/lat so the result really is (lon, lat).
        wgs84_srs.SetAxisMappingStrategy(osr.OAMS_TRADITIONAL_GIS_ORDER)
        utm_srs.SetAxisMappingStrategy(osr.OAMS_TRADITIONAL_GIS_ORDER)

    utm_to_wgs84 = osr.CoordinateTransformation(utm_srs, wgs84_srs)
    return utm_to_wgs84.TransformPoint(easting, northing, 0)

class WGS84Transform(object):
    """Helpers for converting WGS84 longitude/latitude into UTM coordinates.

    Zones follow the regular 6-degree grid; the Norway/Svalbard zone
    exceptions are not handled.
    """

    def get_utm_zone(self, longitude):
        """Return the UTM zone number (1-60) that contains ``longitude``.

        The longitude is wrapped into [-180, 180) first, so 180 maps to zone 1
        (same meridian as -180) and out-of-range values such as 190 still give
        a valid zone instead of 61+ or 0.
        """
        zone = int(((longitude + 180.0) % 360.0) // 6.0) + 1
        # Guard against float rounding of the modulo producing exactly 360.0.
        return min(zone, 60)

    def is_lat_northern(self, latitude):
        """Return 1 if ``latitude`` lies in the northern UTM hemisphere, else 0.

        The equator (0.0) counts as northern.
        """
        return 0 if latitude < 0.0 else 1

    def wgs84_to_utm(self, lon, lat):
        """Project a WGS84 lon/lat into UTM.

        The zone and hemisphere are chosen from THIS point, so points in
        different zones end up in different coordinate systems; use a single
        shared ``osr.CoordinateTransformation`` when points must be compared.

        Returns:
            ``(easting, northing, altitude)`` as returned by ``TransformPoint``.
        """
        utm_srs = osr.SpatialReference()
        # WGS84 datum for the UTM projection
        utm_srs.SetWellKnownGeogCS("WGS84")
        utm_srs.SetUTM(self.get_utm_zone(lon), self.is_lat_northern(lat))

        # Clone ONLY the geographic (lon/lat) part of the UTM definition
        wgs84_srs = utm_srs.CloneGeogCS()
        if hasattr(osr, 'OAMS_TRADITIONAL_GIS_ORDER'):
            # GDAL >= 3 uses lat/lon axis order for geographic CRSs by
            # default; force lon/lat so TransformPoint(lon, lat) is correct.
            wgs84_srs.SetAxisMappingStrategy(osr.OAMS_TRADITIONAL_GIS_ORDER)
            utm_srs.SetAxisMappingStrategy(osr.OAMS_TRADITIONAL_GIS_ORDER)

        wgs84_to_utm = osr.CoordinateTransformation(wgs84_srs, utm_srs)
        return wgs84_to_utm.TransformPoint(lon, lat, 0)

def get_iterable_extent(*args):
    """Return ``[min, max]`` of every argument, flattened into one list.

    ``get_iterable_extent(xs, ys)`` -> ``[min(xs), max(xs), min(ys), max(ys)]``.
    One-shot iterators (e.g. generators) are materialised first so that both
    ``min`` and ``max`` see every element.

    Raises:
        ValueError: if any argument is empty (propagated from ``min``).
    """
    extent = []
    for values in args:
        # An iterator returns itself from iter(); lists/arrays do not.
        if iter(values) is values:
            values = list(values)
        extent.extend((min(values), max(values)))
    return extent


def get_raster_size(min_x, min_y, max_x, max_y, cell_width, cell_height):
    """Return ``(cols, rows)`` needed to cover the given bounds with cells.

    The grid is assumed to be anchored at the (min_x, max_y) corner.  One
    extra cell is added in each direction so that every point in the bounds,
    including one lying exactly on max_x / min_y, maps to a valid cell index,
    and a degenerate extent (a single point) gives a 1 x 1 raster instead of
    0 x 0.

    Cell sizes may be given with either sign (the geotransform's north-south
    resolution is conventionally negative); only their magnitudes are used.

    Raises:
        ZeroDivisionError: if ``cell_width`` or ``cell_height`` is 0.
    """
    cols = int(abs(max_x - min_x) / abs(cell_width)) + 1
    rows = int(abs(max_y - min_y) / abs(cell_height)) + 1
    return cols, rows


def lonlat_to_pixel(lon, lat, inverse_geo_transform, wgs84_to_utm_transform=None):
    """Translate a WGS84 lon/lat into zero-based (column, row) raster indices.

    Args:
        lon: longitude in decimal degrees.
        lat: latitude in decimal degrees.
        inverse_geo_transform: the six inverse geotransform coefficients of
            the target raster.
        wgs84_to_utm_transform: optional ``osr.CoordinateTransformation``
            from WGS84 to the raster's projected CRS.  Supply it whenever
            possible: without it the point is projected into the UTM zone of
            the point itself (legacy behaviour), which only matches the
            raster when the point and the raster share a zone.

    Returns:
        ``(pixel_x, pixel_y)`` as ints, i.e. (column, row).  The continuous
        pixel coordinates are floored, so the raster origin cell is (0, 0);
        points west of / north of the origin give negative indices and points
        beyond the far edges give indices >= the raster size, so callers must
        bounds-check before indexing.
    """
    import math

    if wgs84_to_utm_transform is None:
        utm_x, utm_y = WGS84Transform().wgs84_to_utm(lon, lat)[:2]
    else:
        utm_x, utm_y = wgs84_to_utm_transform.TransformPoint(lon, lat, 0)[:2]

    # Inverse geotransform maps map coordinates to continuous pixel/line
    # coordinates where [0, 1) is the first cell -- no "-1" adjustment needed.
    pixel_x, pixel_y = gdal.ApplyGeoTransform(inverse_geo_transform, utm_x, utm_y)
    return int(math.floor(pixel_x)), int(math.floor(pixel_y))


def create_raster(lons, lats, values, filename="test.tiff", output_format="GTiff",
                  cell_size_meters=50.0, nodata_value=-9999.0):
    """Write point values (WGS84 lon/lat) into a single-band Float32 UTM raster.

    All points are projected into ONE UTM coordinate system -- the zone of the
    centre longitude of the data and the hemisphere of the centre latitude --
    and the raster is sized so that every projected point lands in a cell.
    Cells that receive no point keep ``nodata_value``; if several points fall
    into the same cell, the last one wins.

    Args:
        lons: longitudes in decimal degrees (WGS84).
        lats: latitudes in decimal degrees (WGS84), paired with ``lons``.
        values: value to store for each point, paired with the points via
            ``zip`` (points without a value only contribute to the extent).
        filename: output path handed to the GDAL driver.
        output_format: GDAL driver short name, e.g. "GTiff".
        cell_size_meters: edge length of the square cells, in UTM metres.
        nodata_value: fill value for empty cells, also registered as the
            band's NoData value.  Pass None to fill with 0 and register no
            NoData value.

    Returns:
        ``(pixel_x, pixel_y)`` -- zero-based column and row of the last point
        written, or ``(None, None)`` if ``values`` is empty.

    Raises:
        ValueError: if no lon/lat pairs are given, ``cell_size_meters`` is not
            positive, or ``output_format`` is not a known GDAL driver.
    """
    points_lonlat = [(float(lon), float(lat)) for lon, lat in zip(lons, lats)]
    if not points_lonlat:
        raise ValueError("create_raster needs at least one lon/lat pair")
    if not cell_size_meters > 0:
        raise ValueError("cell_size_meters must be positive")

    driver = gdal.GetDriverByName(output_format)
    if driver is None:
        raise ValueError("unknown GDAL driver: %s" % output_format)

    # One projected CRS for the whole dataset.  Projecting each point into its
    # own UTM zone would mix incompatible coordinate systems as soon as the
    # data crosses a zone boundary.
    all_lons = [lon for lon, _ in points_lonlat]
    all_lats = [lat for _, lat in points_lonlat]
    center_lon = (min(all_lons) + max(all_lons)) / 2.0
    center_lat = (min(all_lats) + max(all_lats)) / 2.0

    wgs84_obj = WGS84Transform()
    utm_srs = osr.SpatialReference()
    utm_srs.SetWellKnownGeogCS("WGS84")
    utm_srs.SetUTM(wgs84_obj.get_utm_zone(center_lon),
                   wgs84_obj.is_lat_northern(center_lat))
    wgs84_srs = utm_srs.CloneGeogCS()  # geographic (lon/lat) part only
    if hasattr(osr, 'OAMS_TRADITIONAL_GIS_ORDER'):
        # GDAL >= 3 uses lat/lon axis order for geographic CRSs by default;
        # force lon/lat so TransformPoint(lon, lat, ...) is unambiguous.
        wgs84_srs.SetAxisMappingStrategy(osr.OAMS_TRADITIONAL_GIS_ORDER)
        utm_srs.SetAxisMappingStrategy(osr.OAMS_TRADITIONAL_GIS_ORDER)
    wgs84_to_utm = osr.CoordinateTransformation(wgs84_srs, utm_srs)

    # Extent from the projected points themselves: UTM grid north is not
    # parallel to the meridians, so the (min_lon, max_lat) corner is not in
    # general the top-left of the projected data.
    utm_points = [wgs84_to_utm.TransformPoint(lon, lat, 0)[:2]
                  for lon, lat in points_lonlat]
    min_x = min(x for x, _ in utm_points)
    max_x = max(x for x, _ in utm_points)
    min_y = min(y for _, y in utm_points)
    max_y = max(y for _, y in utm_points)

    # +1 so points on the far (east / south) edge still fall inside the grid
    # and a single point yields a 1 x 1 raster.
    cols = int((max_x - min_x) / cell_size_meters) + 1
    rows = int((max_y - min_y) / cell_size_meters) + 1

    dataset = driver.Create(filename, cols, rows, 1, gdal.GDT_Float32)
    # top-left x, w-e pixel size, rotation, top-left y, rotation,
    # n-s pixel size (negative: row index grows southwards)
    dataset.SetGeoTransform([min_x, cell_size_meters, 0, max_y, 0, -cell_size_meters])
    dataset.SetProjection(utm_srs.ExportToWkt())

    fill_value = 0.0 if nodata_value is None else nodata_value
    # Raster arrays are indexed (row, col), i.e. data[y, x].
    data = np.full((rows, cols), fill_value, dtype=np.float32)

    pixel_x = pixel_y = None
    for (x, y), value in zip(utm_points, values):
        # min_x / max_y bound every projected point, so both offsets are >= 0
        # and int() floors them to an index inside [0, cols) x [0, rows).
        pixel_x = int((x - min_x) / cell_size_meters)
        pixel_y = int((max_y - y) / cell_size_meters)
        data[pixel_y, pixel_x] = value

    band = dataset.GetRasterBand(1)  # band indices start at 1
    band.WriteArray(data, 0, 0)
    if nodata_value is not None:
        band.SetNoDataValue(nodata_value)
    band.FlushCache()

    # Dropping the last references "closes" the dataset and flushes it.
    band = None
    dataset = None
    return pixel_x, pixel_y

#################
# Main Function #
#################

if __name__ == '__main__':
    # Example run: (longitude, latitude, value) samples in WGS84 degrees.
    samples = [
        (134.6, 45.3, 3),
        (128.7, 56.2, 6),
        (111.9, 23.4, 2),
    ]
    lon, lat, val = (list(column) for column in zip(*samples))

    create_raster(lon, lat, val)