#!/usr/bin/python3
#
# plots halsampler output
#
# Usage: errplt file col1 col2 ... coln
#     assumes the first row of the file holds a label for each data column
#
# x-pos-cmd x-vel-cmd x-pid-out x-pos-fb x-vel-fb
# 0.000000 0.000000 -0.012800 0.005709 -0.011239
# 0.000000 0.000000 -0.012800 0.005709 -0.011239
# 0.000000 0.000000 -0.012800 0.005709 -0.011239
#
# The colx arguments pick off columns to plot from the file
# example:
#
# echo 'x-pos-cmd x-vel-cmd x-pid-out x-pos-fb x-vel-fb' > run.out
# halsampler -c 0 >> run.out
#
# errplt run.out x-pos-cmd x-pos-fb
#
# Copyright 2021 Robert Bond
#
#    Licensed under the Apache License, Version 2.0 (the "License");
#    you may not use this file except in compliance with the License.
#    You may obtain a copy of the License at
#
#        http://www.apache.org/licenses/LICENSE-2.0
#
#    Unless required by applicable law or agreed to in writing, software
#    distributed under the License is distributed on an "AS IS" BASIS,
#    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#    See the License for the specific language governing permissions and
#    limitations under the License.
#
# example
#  python3 plot_hs_xy.py -v -f x-error:1000 -f x-pos-cmd^1? -f x-pos-fb? -f y-error:1000 -f y-pos-cmd^1#x-pos-cmd -f y-pos-fb#x-pos-fb -f z-error -f z-pos-cmd -f z-pos-fb -f a-error -f a-pos-cmd -f a-pos-fb test.out
#
# example of x-y plot delaying pos cmds 1 sample so they align with pid error
#  python3 plot_hs_xy.py -v -f x-pos-cmd^1? -f x-pos-fb? -f y-pos-cmd^1#x-pos-cmd -f y-pos-fb#x-pos-fb test.out
#
# example plotting x, y, z and a with the commands delayed one sample and the errors scaled by 1000, so the errors are easy to read alongside the commands and feedbacks
#  python3 plot_hs_xy.py -v -f x-error:1000 -f x-pos-cmd^1 -f x-pos-fb -f y-error:1000 -f y-pos-cmd^1 -f y-pos-fb -f z-error:1000 -f z-pos-cmd -f z-pos-fb -f a-error:1000 -f a-pos-cmd -f a-pos-fb test.out


import sys
import argparse
import matplotlib.pyplot as plt
import numpy as np
import re

def _shift_samples(column, delay):
    """Return a copy of the 1-D ``column`` shifted by ``delay`` samples.

    A positive delay moves samples later (towards higher indices); a negative
    delay moves them earlier. The end vacated by the shift is padded with the
    nearest surviving sample instead of the values np.roll wraps around.
    Delays at least as long as the column are clamped, leaving it flat.
    """
    shifted = np.array(column, copy=True)
    n = len(shifted)
    if n == 0:
        return shifted
    # Clamp so an oversized delay pads with the edge value instead of raising.
    delay = max(-(n - 1), min(int(delay), n - 1))
    if delay == 0:
        return shifted
    shifted = np.roll(shifted, delay)
    if delay > 0:
        shifted[:delay] = shifted[delay]
    else:
        shifted[delay:] = shifted[delay - 1]
    return shifted


def process_file(f, fields, scale, SmpDly, xval, hide):
    """Read halsampler output from ``f`` and build the arrays to plot.

    Args:
        f: open text file. Line 1 is a whitespace-separated header of column
           labels, and each following line holds one sample.
        fields: header labels to extract, in plot order.
        scale: field -> multiplier applied to that field's values.
        SmpDly: field -> number of samples to delay that field (a negative
                value advances it instead).
        xval: field -> name of another entry in ``fields`` whose (delayed)
              values replace the sample index as x, or "" to plot against
              the sample index.
        hide: field -> True to use the field only as an x source and leave it
              out of the returned plot data.

    Returns:
        (plot_labels, plot_data): the visible field names, and a float array
        of shape (samples, len(plot_labels), 2), where [..., 0] is x and
        [..., 1] is y.

    Exits the process with status 1 in three cases: the file has fewer than
    two lines, a field is missing from the header, or an x source is not one
    of ``fields``.
    """
    lines = f.readlines()
    if len(lines) < 2:
        print("Less than 2 lines in file.")
        print("File example:")
        print("x-pos-cmd x-vel-cmd x-pos.out")
        print("0.000000 0.000000 -0.012800")
        sys.exit(1)
    labels = lines[0].strip().split()
    print(labels)

    for field in fields:
        if field not in labels:
            print(field, "not in file header")
            sys.exit(1)
    for field in fields:
        ref = xval[field]
        if ref and ref not in fields:
            print(ref, "(x values for " + field + ") is not a requested field")
            sys.exit(1)

    # Column position of each requested label (last occurrence wins if the
    # header repeats a label).
    col_idx = {label: i for i, label in enumerate(labels) if label in fields}
    needed = max(col_idx.values(), default=-1) + 1

    rows = []
    # The sample index counts every line after the header, including skipped
    # ones, so x positions stay aligned with the file.
    for sample, line in enumerate(lines[1:]):
        lf = line.split()
        if not lf:
            # Blank line, e.g. a trailing newline at the end of the capture.
            continue
        if lf[0].isalpha():
            print("skipping", lf[0])
            continue
        if len(lf) < needed:
            # Truncated line, e.g. when the capture was interrupted mid-write.
            print("skipping short line", sample + 2)
            continue
        rows.append([[float(sample), float(lf[col_idx[field]]) * scale[field]]
                     for field in fields])

    # reshape keeps a consistent 3-D shape even when there are no rows/fields.
    plot_data = np.array(rows, dtype=float).reshape(len(rows), len(fields), 2)

    # Pass 1: apply every delay before any x values are borrowed, so the
    # result does not depend on the order the fields were given in.
    for i, field in enumerate(fields):
        plot_data[:, i, 1] = _shift_samples(plot_data[:, i, 1], SmpDly[field])

    # Pass 2: optionally plot against another field's y values instead of
    # the sample index.
    for i, field in enumerate(fields):
        if xval[field]:
            plot_data[:, i, 0] = plot_data[:, fields.index(xval[field]), 1]

    # Hidden fields were only needed as x sources; drop them from the output.
    keep = [i for i, field in enumerate(fields) if not hide[field]]
    plot_labels = [fields[i] for i in keep]
    return plot_labels, plot_data[:, keep, :]


class plotter(object):
    """Stack one matplotlib subplot per input file and fill them in order.

    ``self.axs`` is always indexable by subplot number. ``plt.subplots()``
    with no row count returns a single Axes rather than an array, so that case
    is wrapped in a list.
    """

    def __init__(self, nplots):
        """Create a figure with ``nplots`` subplots stacked in one column."""
        if nplots == 1:
            fig, axs = plt.subplots()
            self.fig = [fig]
            self.axs = [axs]
        else:
            self.fig, self.axs = plt.subplots(nplots)
        # Index of the next subplot plot_subplot() will draw into.
        self.pn = 0

    def plot_subplot(self, labels, data, file_name):
        """Draw every series of one file into the next free subplot.

        labels: legend label for each series.
        data: array indexed as data[sample, series, 0 or 1] = x or y.
        file_name: used as the subplot title.
        """
        ax = self.axs[self.pn]
        for i, label in enumerate(labels):
            ax.plot(data[:, i, 0], data[:, i, 1], label=label)
        # Pass the visibility flag positionally. The keyword was ``b`` before
        # Matplotlib 3.5 and ``visible`` after; ``b`` was removed in 3.7.
        ax.grid(True, which='major', axis='x')
        ax.set_title(file_name)
        if labels:
            # With every series hidden there is nothing to put in a legend.
            ax.legend()
        self.pn += 1

    def show(self):
        """Display the figure via plt.show()."""
        plt.show()

def parse_field_spec(spec):
    """Split a ``-f`` argument into its field name and modifiers.

    The spec is a field name followed, in any order, by any of:
    ``:scale``, ``^sampleDelay``, ``#xField`` and ``?`` (hide). An example is
    ``y-pos-cmd^1#x-pos-cmd``. If a modifier is repeated, the last one wins.

    Returns (name, scale, delay, xfield, hidden). Absent modifiers default to
    1.0, 0, "" and False.

    Raises ValueError if the name is empty, or if a ``:``, ``^`` or ``#``
    modifier has a missing or unparsable value.
    """
    # The capturing group keeps each separator as its own token, so the
    # result is: name, sep, value, sep, value, ...
    tokens = re.split(r"([:^#?])", spec)
    name = tokens[0]
    if not name:
        raise ValueError("field spec %r has no field name" % spec)
    scale, delay, xfield, hidden = 1.0, 0, "", False
    for sep, value in zip(tokens[1::2], tokens[2::2]):
        if sep == "?":
            hidden = True
            continue
        if not value:
            raise ValueError("field spec %r: %r needs a value" % (spec, sep))
        if sep == ":":
            scale = float(value)
        elif sep == "^":
            delay = int(value)
        else:  # "#"
            xfield = value
    return name, scale, delay, xfield, hidden

if __name__ == "__main__":
    ap = argparse.ArgumentParser(description='Get log data')
    ap.add_argument('--verbose', '-v', action='count', default=0,
                    help='Print debug info')
    ap.add_argument('--field', '-f', action='append',
                    help='Field name, optional (:scale), optional (^sampleDelay), optional(?) dont show, optional(#fieldName) use Xvalues from another field')
    ap.add_argument('files', nargs='+', help='Files to process')
    args = ap.parse_args()

    xval = {}
    hide = {}
    SmpDly = {}
    scale = {}
    fields = []
    for spec in args.field or []:
        try:
            name, field_scale, delay, xfield, hidden = parse_field_spec(spec)
        except ValueError as e:
            ap.error(str(e))  # prints usage and exits with status 2
        fields.append(name)
        scale[name] = field_scale
        SmpDly[name] = delay
        xval[name] = xfield
        hide[name] = hidden

    if args.verbose > 0:
        print("fields:", fields)
        print("scale:", scale)
        print("SmpDly:", SmpDly)
        print("xval:", xval)
        print("hide:", hide)

    plot = plotter(len(args.files))
    for fname in args.files:
        # Only file access belongs in the try block. Plotting errors must not
        # be reported as an unreadable file.
        try:
            with open(fname, 'r') as f:
                labels, data = process_file(f, fields, scale, SmpDly, xval, hide)
        except OSError as e:
            print("Can't open file", fname, "-", e)
            sys.exit(1)
        plot.plot_subplot(labels, data, fname)
    plot.show()
