from snakebids import bids, generate_inputs
from pathlib import Path

def path_placeholderstring():
	"""Return the per-recording path template, without any file extension.

	The result is ``{sub}/{session}/{modality}/{sub}-{session}-{task}``; the
	brace fields are left unfilled so Snakemake treats them as wildcards.
	"""
	segments = ("{sub}", "{session}", "{modality}", "{sub}-{session}-{task}")
	return "/".join(segments)

def generate_input_path(ext):
	"""Return the raw-input file template for one recording.

	``ext`` is appended verbatim to the shared template from
	``path_placeholderstring()``, so it must include its leading dot
	(e.g. ``".lfp"``).
	"""
	return "".join((path_placeholderstring(), ext))

def generate_pipe_path(step_name, ext):
	"""Return the output file template for pipeline step ``step_name``.

	Outputs live under ``preproc-results/<step_name>/`` followed by the shared
	recording template; ``ext`` (including its leading dot) is appended verbatim.
	"""
	stem = Path('preproc-results', step_name, path_placeholderstring())
	return str(stem) + ext

rule all:
	"""Per-recording aggregate target.

	Requests the interpolated signal plus both bad-channel diagnostic plots for
	one {sub}/{session}/{modality}/{task} combination, then touches an empty
	marker file under preproc-results/all/ once they exist.
	"""
	output:
		touch(generate_pipe_path("all", ".all"))
	input:
		interp=generate_pipe_path("interpolate", ".zarr"),
		featuremap=generate_pipe_path("badlabel", ".featuremap.png"),
		pairplot=generate_pipe_path("badlabel", ".pairplot.png")

rule raw_zarr:
	"""Convert one recording's .lfp file into a zarr store.

	The .lfp file is read together with its matching .xml file by
	cogpy's ecog_io.from_file. The loaded signal is renamed "sigx" before being
	written, which is the variable name the downsample rule reads back from its
	zarr input.
	"""
	input:
		dat=generate_input_path(".lfp"),
		xml=generate_input_path(".xml")
	output:
		zarr=directory(generate_pipe_path("raw_zarr", ".zarr"))
	# NOTE(review): reserves 4 cores, but the run block never passes `threads`
	# to ecog_io -- confirm the reader actually uses them.
	threads: 4
	run:
		from cogpy.io import ecog_io
		signal = ecog_io.from_file(input.dat, input.xml)
		signal.name = "sigx"
		ecog_io.to_zarr(output.zarr, signal)

rule lowpass:
	"""Lowpass-filter one recording with scripts/01_filter.py.

	The cutoff, filter order and band type are handed to the script as params.
	The input is the recording's "denoised" zarr store.
	"""
	input:
		# NOTE(review): no rule in this file writes the "denoised" store; it looks
		# externally produced. ancient() tells Snakemake to ignore its timestamp,
		# so a newer denoised store alone will not re-trigger this rule.
		raw=ancient(generate_pipe_path("denoised", ".zarr"))
	output:
		filtered=directory(generate_pipe_path("lowpass", ".zarr"))
	params: 
		cutoff=config['prep']['cutoff_lp'],  # cutoff frequency in Hz
		order=config['prep']['order'],  # filter order
		btype='lowpass'
	script:
		"scripts/01_filter.py"

rule downsample:
	"""Decimate the lowpass-filtered signal by an integer factor.

	Keeps every ``factor``-th sample along ``time`` and divides the stored
	sampling rate ``attrs['fs']`` by the same factor. Plain decimation is only
	safe when the earlier lowpass cutoff sits well below the new Nyquist rate,
	so the rule refuses to write anything unless the new fs exceeds
	2.5 x config['prep']['cutoff_lp'].

	Raises:
		ValueError: if the downsampled fs is too low for the lowpass cutoff.
	"""
	params: 
		factor=config['prep']['factor']  # downsampling factor (keep every Nth sample)
	input:
		raw=generate_pipe_path("lowpass", ".zarr")
	output:
		downsmp=directory(generate_pipe_path("downsample", ".zarr"))
	run:
		from cogpy.io import ecog_io
		sigx = ecog_io.from_zarr(input.raw)['sigx']
		new_fs = sigx.attrs['fs'] / params.factor
		cutoff = config['prep']['cutoff_lp']
		# Validate before slicing/writing. Use an explicit raise rather than
		# `assert` so the check is not stripped under `python -O`. `not (a > b)`
		# also rejects NaN, matching the previous assert.
		if not new_fs > cutoff * 2.5:
			raise ValueError(
				f"Downsampled signal's fs ({new_fs}) is too low for the lowpass "
				f"cutoff ({cutoff}); reduce the downsampling factor or "
				"decrease the lowpass cutoff frequency."
			)
		sigx_downsmp = sigx.isel(time=slice(None, None, params.factor))
		sigx_downsmp.name = "sigx"
		sigx_downsmp.attrs['fs'] = new_fs
		ecog_io.to_zarr(output.downsmp, sigx_downsmp)

rule feature:
	"""Compute windowed features from the downsampled signal (scripts/02_feature.py).

	Window length and step are read from config['bad'] and are given in samples.
	"""
	input:
		# Keyed `lowpass` although it is the downsampled store; presumably the
		# script reads snakemake.input.lowpass -- confirm before renaming.
		lowpass=generate_pipe_path("downsample", ".zarr")
	output:
		feature=directory(generate_pipe_path("feature", ".zarr"))
	params:
		window_size=config['bad']['window_size'],  # in samples
		window_step=config['bad']['window_step'],  # in samples
		zscore=True  # do not change
	script:
		"scripts/02_feature.py"

rule badlabel:
	"""Derive bad-channel labels (.npy) from the extracted features.

	The work is done by scripts/03_badlabel.py. The neighbour count comes from
	config['bad']['knn'].
	"""
	params:
		knn=config['bad']['knn'],  # number of neighbors
	input:
		feature=generate_pipe_path("feature", ".zarr")
	output:
		badlabel=generate_pipe_path("badlabel", ".npy")
	script:
		"scripts/03_badlabel.py"

rule plot_feature_maps:
	"""Render diagnostic plots of the features together with the detected bad channels.

	Writes a feature-map figure and a pair plot next to the bad-label file,
	using scripts/plot_feature_maps.py.
	"""
	output:
		featuremap=generate_pipe_path("badlabel", ".featuremap.png"),
		pairplot=generate_pipe_path("badlabel", ".pairplot.png")
	input:
		feature=generate_pipe_path("feature", ".zarr"),
		badlabel=generate_pipe_path("badlabel", ".npy")
	script:
		"scripts/plot_feature_maps.py"

rule interpolate:
	"""Produce the interpolated zarr store from the raw signal and the bad labels.

	Takes the raw zarr conversion and the bad-label file and writes an
	"interpolate" zarr directory. Presumably this replaces the channels flagged
	as bad by interpolation -- TODO confirm once the MATLAB call is filled in.
	"""
	input:
		raw=generate_pipe_path("raw_zarr", ".zarr"),
		badlabel=generate_pipe_path("badlabel", ".npy")
	output:
		interp=directory(generate_pipe_path("interpolate", ".zarr"))
	# Earlier Python implementation, kept for reference:
	# script:
		# "scripts/04_interpolate.py"
	# NOTE(review): the command below is an unfilled placeholder (no real script
	# path, function name or file arguments), so this rule cannot succeed as
	# written. It should pass {input.raw}, {input.badlabel} and {output.interp}
	# to the lab-toolbox function -- TODO fill in.
	shell:
		"matlab -batch /path/to/scritpt[functionname in lab toolbox] (input file) (output file)"

