from osgeo import gdal
import numpy as np
# Compute NDVI = (NIR - Red) / (NIR + Red) from a multi-band image and write it
# as a single-band Float32 GeoTIFF that reuses the input's geotransform and
# projection. Band 1 is read as red and band 4 as near-infrared; this assumes a
# 4-band R, G, B, NIR ordering (as in NAIP tiles) -- confirm for other sources.

# Forward slashes resolve on both Windows and POSIX; the previous backslash
# path only worked on Windows.
fn_img = "../data/input/m_4311515_ne_11_1_20150908_20160104.jp2"
fn_ndvi = "../data/output/ndvi.tif"
nodata = -9999.0

# gdal.Open returns None (rather than raising) when the file cannot be opened.
ds_in = gdal.Open(fn_img)
if ds_in is None:
    raise RuntimeError(f"Could not open input image: {fn_img}")
driver = gdal.GetDriverByName('GTiff')

# Cast to float before arithmetic so integer bands (e.g. uint8) cannot wrap.
r = ds_in.GetRasterBand(1).ReadAsArray().astype(float)
n = ds_in.GetRasterBand(4).ReadAsArray().astype(float)

denom = n + r
# Where NIR + Red == 0 the ratio is 0/0 = NaN (or +/-inf). NaN would slip past
# a plain range test because every comparison with NaN is False, so non-finite
# results are masked explicitly and the division warnings are silenced.
with np.errstate(divide='ignore', invalid='ignore'):
    ndvi = (n - r) / denom
    valid = np.isfinite(ndvi) & (ndvi >= -1.0) & (ndvi <= 1.0)
ndvi = np.where(valid, ndvi, nodata)

ds_out = driver.Create(fn_ndvi, xsize=ndvi.shape[1], ysize=ndvi.shape[0], bands=1, eType=gdal.GDT_Float32)
if ds_out is None:
    raise RuntimeError(f"Could not create output raster: {fn_ndvi}")
ds_out.SetGeoTransform(ds_in.GetGeoTransform())
ds_out.SetProjection(ds_in.GetProjection())
band_out = ds_out.GetRasterBand(1)
band_out.SetNoDataValue(nodata)
band_out.WriteArray(ndvi)

# Dropping the references closes the datasets; the output is flushed to disk
# when ds_out is released. Release the band before its parent dataset.
band_out = None
ds_in = None
ds_out = None