#!/usr/bin/FEWXPYTHON

########### imports

from datetime import datetime
import glob
import math
import subprocess
import sys

import numpy as np
import pyart
import xarray as xr
from cfgrib.xarray_to_grib import to_grib
from scipy.interpolate import griddata
#import os; os.environ['ESMFMKFILE'] = '/fewxops/software/anaconda3/envs/fewx_operational_3.10.8/lib/esmf.mk'; import xesmf


########### constants
# earth geometry -- used to turn the composite's spacing in meters into degrees
Rearth = 6378137                        # earth radius in meters (WGS84 equatorial value)
pi = math.pi                            # 3.14159...
degPerMeter = 180 / (pi * Rearth)       # degrees of arc per meter along a great circle (= degrees lon per meter at the equator)

# run-time thresholds
too_old = 40                            # minutes -- radar files whose scan time is older than this get ignored
now = datetime.utcnow()                 # naive UTC "now"; each radar file's scan time is compared against it
NANVAL = -15.0                          # what NaNs become in the radar composite step: 0.0 was not great, anything below 0 works, but the further below zero the slower the run -- hence -15.0
########### functions
def EXITING(status=None):
	"""Terminate the script by raising SystemExit via sys.exit(status).

	Uses sys.exit instead of the bare `exit()` helper, which is only installed
	by the `site` module (absent under `python -S` or in frozen builds).  With
	the default status of None the process exits with code 0, exactly as the
	argument-less EXITING() calls in this script always have.
	"""
	sys.exit(status)


### obtain list of rad data files in $LATEST/rad
# NOTE(review): the first path below used to be assigned and then immediately overwritten by the
# XXX-marked test path, so only the XXX path is actually read.  Confirm which directory is intended
# before operational use.
#raddir="/fewxops/data/latest/rad/nexrad3datatest"
raddir="/fewxops/scripts/rad/nexrad3datatest" ## XXX
print(str(datetime.now()) + " - reading radar files from "+raddir)
print(str(datetime.now())+" - Now  UTC "+str(now))
# glob returns files in filesystem order; sort so every run processes the radars in the same order
nexrad3files=sorted(glob.glob(raddir+"/sn.last.*"))
if not nexrad3files:
	print ("No sn.last.* files found in "+ raddir)
	EXITING()

### build the per-radar master lists, all in the same order:
###   availrads             -- radar ids (used for logging)
###   master_array_radsites -- pyart Radar objects (fed to pyart to generate the composite)
###   decluttereddbz        -- despeckle gate filters, keyed 0..n-1 in the same order as the lists above
availrads=[];master_array_radsites=[];decluttereddbz={}


def _radar_id_from_header(path):
	"""Return characters 4-6 of the second line of the file at `path`, right-stripped.

	Pure-Python equivalent of the shell pipeline `sed -n 2p FILE | cut -c 4-6`
	this script previously spawned once per file (presumably those characters
	hold the 3-letter radar id, e.g. "BMX").  Reading the bytes directly avoids
	a shell per file and cannot be broken by spaces or shell metacharacters in
	the file name.  Returns "" when the file cannot be opened or has no second
	line, matching the empty output the pipeline produced in those cases.
	"""
	try:
		with open(path, "rb") as fh:
			fh.readline()                  # line 1 is skipped, as `sed -n 2p` did
			second_line = fh.readline()
	except OSError:
		return ""
	return second_line[3:6].decode("ascii", errors="replace").rstrip()


for radfile in nexrad3files:
	try:
		radardata=pyart.io.read_nexrad_level3(radfile)
	except Exception:  # any read failure skips this file; Ctrl-C / SystemExit still propagate
		print(radfile+" being skipped -- seems corrupt")
		continue

	# scan time from the time units string, presumably "seconds since YYYY-MM-DDTHH:MM:SSZ"
	# (chars 14-33 are the timestamp) -- parsed as naive UTC to compare with `now`
	radtime=datetime.strptime(radardata.time['units'][14:34].replace("T"," ").replace("Z",""),"%Y-%m-%d %H:%M:%S")
	radid=_radar_id_from_header(radfile)
	minutesold=(now-radtime).total_seconds()/60

	#### xxxx
	# NOTE(review): TEMPORARY TEST OVERRIDE -- discards the real age computed above, so every radar is
	# treated as fresh except the four listed, which are forced stale.  Remove before operational use.
	minutesold=1.0
	if radid in ("BMX", "MSX", "KJK", "KSG"): minutesold = 100.0
	### xxx

	if minutesold <= too_old:
		master_array_radsites.append(radardata)
		availrads.append(radid)
		# declutter the rad field -- via source code comments in https://arm-doe.github.io/pyart/_modules/pyart/correct/despeckle.html#despeckle_field
		# the "size" setting is a threshold number of clustered gates; clusters smaller than this are treated as noise and filtered -- we tested 1, 9, 25 (grid pts) and 9 seemed best
		decluttereddbz[len(decluttereddbz)]=pyart.correct.despeckle_field(radardata, 'reflectivity', size=9)
	else:
		print(str(datetime.now())+" - "+radid+"  UTC "+str(radtime)+" - ignored - older than "+str(too_old)+" minutes (actual oldness="+str(int(minutesold))+" minutes)")

if not availrads:
	print ("There were no radar files with a time stamp less than "+str(too_old)+" minutes old")
	EXITING()

print(str(datetime.now()) + " - "+str(len(availrads))+" radar sites will be composited:",*sorted(availrads))


##### prelim critical vars for pyart composite
print(str(datetime.now()) + " - generating composite -- at 1km res can take over 2 minutes")
US_centlat=38; US_centlon=-98           # US composite center lat/lon (the grid origin)
US_ew_meters = 4736                     # NOTE: despite the name, this is in KILOMETERS -- east/west extent of the CONUS domain
US_ns_meters = 3000                     # NOTE: despite the name, this is in KILOMETERS -- north/south extent of the CONUS domain
compositeres = 1.5                      # desired horizontal resolution of the radar composite in km
i=int((US_ew_meters/ compositeres) + 1) # grid points in x: domain width / resolution, +1 so both edges get a point
j=int((US_ns_meters/ compositeres) + 1) # grid points in y: domain height / resolution, +1 so both edges get a point
z=4500                                  # top of the grid in m -- for reference the highest 88D site is SLC at 3231 m, so 1000 m above ground there would be 4231 m
                                        # NOTE(review): per the pyart docs the z limits are relative to grid_origin_alt, which defaults to the first radar's altitude -- so "above ground" is approximate; confirm
vlevels=int(z/500)+1                    # vertical levels (~500 m apart, an arbitrary choice) sifted through so we aren't stuck with cones of silence


#### generate composite -- these settings and more are documented in the SOURCE CODE COMMENTS here https://arm-doe.github.io/pyart/_modules/pyart/map/grid_mapper.html#grid_from_radars
pyart_compositegrid = pyart.map.grid_from_radars(
	master_array_radsites,                  # list of pyart Radar objects built above
	grid_shape=(vlevels, j, i),             # number of points in z, y, x
	grid_origin=(US_centlat,US_centlon),    # lat, lon of the grid origin (domain center)
	grid_limits=((1,z), (-1000*US_ns_meters/2, 1000*US_ns_meters/2), (- 1000*US_ew_meters/2, 1000*US_ew_meters/2)),	# meters relative to the origin: z from 1 m to z; y and x span +/- half the domain (km * 1000)
	gatefilters=list(decluttereddbz.values()),  # one despeckle gate filter per radar, same order as master_array_radsites
	min_radius=2000.0,                      # smaller than 1750 gives you holes, 2000 seemed best visually
	weighting_function='barnes',            # barnes faster, seemed better than cressman -- waaay better than nearest. cressman slowest
	)
# when debugging it's easier to write this to an ncf file, then read it back on the command line with
#   import pyart; import numpy as np; pyart_compositegrid=pyart.io.read_grid("pyart.ncf")
# and start using the data interactively
#pyart.io.write_grid("pyart1km.ncf",pyart_compositegrid); exit()

# once in ncf format you can open up a command line and then do the following
#  from cfgrib.xarray_to_grib import to_grib; import xarray as xr; import pyart; import numpy as np; pyart_compositegrid=pyart.io.read_grid("pyart.ncf"); dbz=np.max(np.nan_to_num(pyart_compositegrid.fields['reflectivity']['data'].data,nan=0.0),axis=0); dbzlon,dbzlat=pyart_compositegrid.get_point_longitude_latitude(level=0);srclon=dbzlon.flatten(); srclat=dbzlat.flatten(); srcdat=dbz.flatten();pyartarray=xr.DataArray(srcdat,dims='values',coords={'latitude':('values',srclat),'longitude':('values',srclon)});dspyart=pyartarray.to_dataset(name="refc")
#  dspyart.refc.attrs={ 'GRIB_gridType': 'mercator','GRIB_numberOfPoints': 569748, 'GRIB_DxInMetres': 5000.0, 'GRIB_DyInMetres': 5000.0, 'GRIB_LaDInDegrees': 20.0, 'GRIB_Nx': 948, 'GRIB_Ny': 601, 'GRIB_latitudeOfFirstGridPointInDegrees': 22.174179120296714, 'GRIB_longitudeOfFirstGridPointInDegrees': 229.85753933907554, 'GRIB_latitudeOfLastGridPointInDegrees':  51.48982568171301,  'GRIB_longitudeOfLastGridPointInDegrees': 294.142606609244,'GRIB_orientationOfTheGridInDegrees': 0.0 }
#   to_grib(dspyart,"pcm.grb2",grib_keys={'edition':2,},no_warn=True)
# -- mercator reference only
#dsmerc=xr.open_dataset('nam_hawaii_mercator.grb2',engine='cfgrib',backend_kwargs={'read_keys':['LaDInDegrees', 'latitudeOfFirstGridPointInDegrees', 'latitudeOfLastGridPointInDegrees', 'longitudeOfFirstGridPointInDegrees', 'longitudeOfLastGridPointInDegrees', 'Nx', 'Ny', 'DxInMetres', 'DyInMetres', 'orientationOfTheGridInDegrees'],'filter_by_keys': {'shortName': 'refc'}})
#mercarray=dsmerc['refc'] # just the var refc
#dsmercnew=mercarray.to_dataset(name="refc") # push that back into a data set
#to_grib(dsmercnew,"dsmercnew.grb2",grib_keys={'edition':2,},no_warn=True)
### ------
#------------------------ pygrib writing 
#print("pygribbing")
#import pygrib
#grbs=pygrib.open('test.grb2') # a field from a grib file in mercator proj -- like HI or Guam grid
#grb=grbs.message(1)
#grbs.close()
##print(grb.keys())
##for x in grb.keys():
##       if (x != "analDate" and x != "validDate"):
##               print (x,grb[x])
#grb.values=dbz
#grb.gridDefinitionTemplateNumber=10
#grb.longitudeOfFirstGridPointInDegrees=360+minlon
#grb.latitudeOfFirstGridPointInDegrees=minlat
#grb.longitudeOfLastGridPointInDegrees=360+maxlon
#grb.latitudeOfLastGridPointInDegrees=maxlat
#grb.Ni=dbz.shape[1]
#grb.Nj=dbz.shape[0]
#grb.DiInMetres=pyart_compositegrid_xres*100
#grb.DjInMetres=pyart_compositegrid_yres*100
#grb.orientationOfTheGridInDegrees=0
#grb.level=0
#grb.dataDate=19500101
#grb.dataTime=0000
#msg=grb.tostring()
#grbout=open("merc.grb2",'wb')
#grbout.write(msg)
#grbout.close()
#print("done pygribbing")
#exit()
#----------------------------------
#

######## isolate dbz lat lon arrays
# Ideal approach: at each point, use the reflectivity from the radar whose 0.5 deg elevation slice is
# 1 km AGL or less, taking the max when several radars qualify.
# Current (poor man's) approach: take the highest value found in any vertical level.  Collapsing the
# 3-D grid (axis 0 = vertical levels) to 2-D this way minimizes cones of silence; NaNs are replaced
# with NANVAL first.
# NOTE(review): if the field is a masked array, `.data` drops its mask, so masked cells keep whatever
# value is stored underneath (not necessarily NaN) -- confirm that is the intended input here.
# tip: pyart_compositegrid.fields['reflectivity']['data'][level] shows a single level
print(str(datetime.now()) + " - isolating dbz data")
_refl_all_levels = np.nan_to_num(pyart_compositegrid.fields['reflectivity']['data'].data, nan=NANVAL)
dbz = _refl_all_levels.max(axis=0)
del _refl_all_levels   # drop the full 3-D copy before the memory-heavy remapping further down

######## convert pyart composite grid to standard latlon via scipy prep and execution
print(str(datetime.now()) + " - prepping source data")
# lon/lat of every composite grid point; expected to match dbz's 2-D shape since the two are
# flattened and paired element-wise below
dbzlon,dbzlat=pyart_compositegrid.get_point_longitude_latitude(level=0)

# make 1-D arrays for scipy, then keep only points above NANVAL (drops the NaN fill value and anything that low)
srclon=dbzlon.flatten(); srclat=dbzlat.flatten(); srcdat=dbz.flatten()
valid = srcdat > NANVAL                 # computed once, applied to all three arrays
srclat=srclat[valid]
srclon=srclon[valid]
srcdat=srcdat[valid]
del valid
if srcdat.size == 0:
	# np.min/np.max below raise ValueError on empty arrays -- bail out with a message instead of a traceback
	print(str(datetime.now()) + " - no composite points above "+str(NANVAL)+" dBZ -- nothing to remap")
	EXITING()
srclonlatpaired=np.column_stack((srclon,srclat))   # (npoints, 2) lon/lat pairs for griddata


# establish target grid
print(str(datetime.now()) + " - prepping target lat lon grid")
minlat=np.min(srclat);maxlat=np.max(srclat);minlon=np.min(srclon);maxlon=np.max(srclon)
refl_xcoords = pyart_compositegrid.x['data']; refl_ycoords = pyart_compositegrid.y['data']  # cartesian coordinates in m relative to the grid origin -- not lat lons
pyart_compositegrid_xres = (refl_xcoords[1] - refl_xcoords[0]) / 1000  # km; should be ~compositeres
pyart_compositegrid_yres = (refl_ycoords[1] - refl_ycoords[0]) / 1000  # km; pyart's grid is cartesian, so this should match xres too
# spacing in degrees: km -> m -> degrees.  NOTE: degPerMeter is degrees per meter along a great circle,
# so a longitude step of this size spans the stated km only at the equator (cos(lat) fewer km elsewhere)
targetgrid_dlon = pyart_compositegrid_xres*1000*degPerMeter
targetgrid_dlat = pyart_compositegrid_yres*1000*degPerMeter
targetgridlonlist=np.arange(minlon,maxlon+targetgrid_dlon,targetgrid_dlon)   # arange's stop is exclusive, so go one step past maxlon
targetgridlatlist=np.arange(minlat,maxlat+targetgrid_dlat,targetgrid_dlat)
targetgridlon,targetgridlat=np.meshgrid(targetgridlonlist,targetgridlatlist)  # 2-D, shape (nlat, nlon)
targetgridlonlatpaired=np.column_stack((targetgridlon.ravel(),targetgridlat.ravel()))


##new test with xesmf 
#targetgrid2=xesmf.util.grid_2d(-125.0,-65.0,pyart_compositegrid_xres*1000*degPerMeter,19.0,53.0,pyart_compositegrid_yres*1000*degPerMeter)
#dspyart=xr.Dataset({"dBZ":(("y","x"),dbz)},coords={"lon":(("y","x"),dbzlon),"lat":(("y","x"),dbzlat)})
#regridder=xesmf.Regridder(dspyart,targetgrid2,'bilinear')
#latlongrid2=regridder(dspyart,keep_attrs=True)
#dlat=float(latlongrid2.lat.values[1][0])-float(latlongrid2.lat.values[0][0])
#dlon=float(latlongrid2.lon.values[0][1])-float(latlongrid2.lon.values[0][0])
#lonmin=latlongrid2.lon.values.min();lonmax=latlongrid2.lon.values.max()
#latmin=latlongrid2.lat.values.min();latmax=latlongrid2.lat.values.max()
#latlongrid2.attrs={'GRIB_gridType':'regular_ll','GRIB_gridDefinitionDescription':'Latitude/longitude','GRIB_NV':0,'GRIB_localTablesVersion':0,'GRIB_Nx': latlongrid2.dBZ.shape[1],'GRIB_Ny': latlongrid2.dBZ.shape[0] ,'GRIB_iDirectionIncrementInDegrees': dlon, 'GRIB_jDirectionIncrementInDegrees': dlat, 'GRIB_longitudeOfFirstGridPointInDegrees': lonmin, 'GRIB_latitudeOfFirstGridPointInDegrees': latmin, 'GRIB_longitudeOfLastGridPointInDegrees': lonmax, 'GRIB_latitudeOfLastGridPointInDegrees': latmax, 'GRIB_jScansPositively':1}
#GRIBFILE="radcompx.grb2"
#print(str(datetime.now()) + " - writing to grib2 file "+GRIBFILE)
#to_grib(latlongrid2,GRIBFILE,grib_keys={'edition':2},no_warn=True)
#print(str(datetime.now())+ " - done")
#exit()

### new test with tims method
#import time
#from scipy import interpolate
#Rearth = 6378137
#degPerMeter = 360 / (2 * np.pi * Rearth)
#dtheta = 1 * 1000 * degPerMeter
# Load and filter input data
#ds = xr.open_dataarray("./pyart2.ncf")
#z = ds.values
#input_lats = ds.latitude.values
#input_lons = ds.longitude.values
##z=srcdat; input_lats=srclat; input_lons=srclon
#input_lons = input_lons[z > -9999]
#input_lats = input_lats[z > -9999]
#z = z[z > -9999]
#input_coords = np.column_stack([input_lons, input_lats])
# Make output grid
#lon_edges = np.arange(input_lons.min(), input_lons.max(), dtheta)
#lat_edges = np.arange(input_lats.min(), input_lats.max(), dtheta)
#lon_edges = lon_edges + dtheta / 2
#lat_edges = lat_edges + dtheta / 2
#nlon = len(lon_edges)
#nlat = len(lat_edges)
#lat, lon = np.meshgrid(lat_edges, lon_edges)
#lat, lon = lat.ravel(), lon.ravel()
# do interp
#start_time = time.time()
#grid_z = interpolate.griddata(input_coords, z, (lon, lat), method="linear")
#grid_z = grid_z.reshape(nlon, nlat).T
#end_time = time.time()
#print(f"Interpolation took {end_time - start_time:.2f} seconds")
#exit()


### scipy remapping --- this is slow
print(str(datetime.now()) + " - mapping source data to target lat lon grid - at 1km hres can take 4 min")
# linear interpolation over a Delaunay triangulation of the source points; target points outside
# their convex hull come back as NaN (griddata's default fill_value)
latlongrid=griddata(srclonlatpaired,srcdat,targetgridlonlatpaired,method="linear")

# back to 2-D (lat rows, lon columns) for writing
latlongrid_reshaped=latlongrid.reshape(targetgridlon.shape)

#### write to grib2
GRIBFILE="radcomp.grb2"
print(str(datetime.now()) + " - writing to grib2 file "+GRIBFILE)
radcomp=xr.DataArray(data=latlongrid_reshaped,dims=['latitude','longitude'],coords={'latitude':targetgridlatlist,'longitude':targetgridlonlist},)
xwrite=radcomp.to_dataset(name="result") # name is arbitrary, literally doesnt matter
to_grib(xwrite,GRIBFILE,grib_keys={'edition':2},no_warn=True)


### exit -- sys.exit rather than the site-module `exit()` helper, which is not always installed
print(str(datetime.now())+ " - done")
sys.exit()
