1// nolint: govet
2package main
3
4import (
5	"encoding/json"
6	"fmt"
7	"math"
8	"os"
9	"strings"
10
11	"github.com/alecthomas/kong"
12	"github.com/alecthomas/participle"
13)
14
// cli holds the command-line flags and positional arguments. It is filled in
// by kong.Parse in main.
//
// NOTE(review): the Expression tag uses bare `arg required` keys rather than
// the key:"value" form; this is presumably why the file carries
// `nolint: govet` — confirm before removing either.
var cli struct {
	AST        bool               `help:"Print AST for expression."`
	Set        map[string]float64 `short:"s" help:"Set variables."`
	Expression []string           `arg required help:"Expression to evaluate."`
}
20
// Operator identifies one of the four binary arithmetic operators.
type Operator int

// Supported operators. Note that OpMul is the zero value of Operator.
const (
	OpMul Operator = iota
	OpDiv
	OpAdd
	OpSub
)

// operatorMap translates an operator token from the source text into its
// Operator value.
var operatorMap = map[string]Operator{"+": OpAdd, "-": OpSub, "*": OpMul, "/": OpDiv}

// Capture sets o from the first captured token in s.
//
// It returns an error, leaving o unchanged, when s is empty or when the token
// is not a key of operatorMap. Without these checks an unknown token would
// silently become OpMul (the zero value) and an empty capture would panic.
func (o *Operator) Capture(s []string) error {
	if len(s) == 0 {
		return fmt.Errorf("no operator token captured")
	}
	op, ok := operatorMap[s[0]]
	if !ok {
		return fmt.Errorf("unsupported operator %q", s[0])
	}
	*o = op
	return nil
}
36
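// Grammar implemented by the types below:
//
//	E --> T {( "+" | "-" ) T}
//	T --> F {( "*" | "/" ) F}
//	F --> P ["^" P]
//	P --> number | variable | "(" E ")"
//
// Unlike the textbook grammar, the exponent is a single P (so "^" does not
// chain) and unary minus ("-" T) is not supported.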
41
// Value is a primary operand (P in the grammar): a numeric literal, a
// variable name, or a parenthesised sub-expression.
//
// The grammar tags are alternatives, so a parsed Value is expected to have
// exactly one field set; String and Eval check the fields in declaration
// order and fall back to Subexpression.
type Value struct {
	Number        *float64    `  @(Float|Int)`
	Variable      *string     `| @Ident`
	Subexpression *Expression `| "(" @@ ")"`
}
47
// Factor is a base Value optionally raised to a power (F in the grammar).
// Exponent is nil when no "^" clause was present.
type Factor struct {
	Base     *Value `@@`
	Exponent *Value `[ "^" @@ ]`
}
52
// OpFactor is one "*" or "/" operator together with the Factor on its
// right-hand side; a Term holds a sequence of these.
type OpFactor struct {
	Operator Operator `@("*" | "/")`
	Factor   *Factor  `@@`
}
57
// Term is a leading Factor followed by zero or more multiplicative
// operator/Factor pairs (T in the grammar).
type Term struct {
	Left  *Factor     `@@`
	Right []*OpFactor `{ @@ }`
}
62
// OpTerm is one "+" or "-" operator together with the Term on its
// right-hand side; an Expression holds a sequence of these.
type OpTerm struct {
	Operator Operator `@("+" | "-")`
	Term     *Term    `@@`
}
67
// Expression is the root of the grammar (E): a leading Term followed by zero
// or more additive operator/Term pairs.
type Expression struct {
	Left  *Term     `@@`
	Right []*OpTerm `{ @@ }`
}
72
73// Display
74
75func (o Operator) String() string {
76	switch o {
77	case OpMul:
78		return "*"
79	case OpDiv:
80		return "/"
81	case OpSub:
82		return "-"
83	case OpAdd:
84		return "+"
85	}
86	panic("unsupported operator")
87}
88
89func (v *Value) String() string {
90	if v.Number != nil {
91		return fmt.Sprintf("%g", *v.Number)
92	}
93	if v.Variable != nil {
94		return *v.Variable
95	}
96	return "(" + v.Subexpression.String() + ")"
97}
98
99func (f *Factor) String() string {
100	out := f.Base.String()
101	if f.Exponent != nil {
102		out += " ^ " + f.Exponent.String()
103	}
104	return out
105}
106
107func (o *OpFactor) String() string {
108	return fmt.Sprintf("%s %s", o.Operator, o.Factor)
109}
110
111func (t *Term) String() string {
112	out := []string{t.Left.String()}
113	for _, r := range t.Right {
114		out = append(out, r.String())
115	}
116	return strings.Join(out, " ")
117}
118
119func (o *OpTerm) String() string {
120	return fmt.Sprintf("%s %s", o.Operator, o.Term)
121}
122
123func (e *Expression) String() string {
124	out := []string{e.Left.String()}
125	for _, r := range e.Right {
126		out = append(out, r.String())
127	}
128	return strings.Join(out, " ")
129}
130
131// Evaluation
132
133func (o Operator) Eval(l, r float64) float64 {
134	switch o {
135	case OpMul:
136		return l * r
137	case OpDiv:
138		return l / r
139	case OpAdd:
140		return l + r
141	case OpSub:
142		return l - r
143	}
144	panic("unsupported operator")
145}
146
147func (v *Value) Eval(ctx Context) float64 {
148	switch {
149	case v.Number != nil:
150		return *v.Number
151	case v.Variable != nil:
152		value, ok := ctx[*v.Variable]
153		if !ok {
154			panic("no such variable " + *v.Variable)
155		}
156		return value
157	default:
158		return v.Subexpression.Eval(ctx)
159	}
160}
161
162func (f *Factor) Eval(ctx Context) float64 {
163	b := f.Base.Eval(ctx)
164	if f.Exponent != nil {
165		return math.Pow(b, f.Exponent.Eval(ctx))
166	}
167	return b
168}
169
170func (t *Term) Eval(ctx Context) float64 {
171	n := t.Left.Eval(ctx)
172	for _, r := range t.Right {
173		n = r.Operator.Eval(n, r.Factor.Eval(ctx))
174	}
175	return n
176}
177
178func (e *Expression) Eval(ctx Context) float64 {
179	l := e.Left.Eval(ctx)
180	for _, r := range e.Right {
181		l = r.Operator.Eval(l, r.Term.Eval(ctx))
182	}
183	return l
184}
185
// Context maps variable names to their values; Value.Eval looks variables up
// in it during evaluation.
type Context map[string]float64
187
188func main() {
189	ctx := kong.Parse(&cli,
190		kong.Description("A basic expression parser and evaluator."),
191		kong.UsageOnError(),
192	)
193
194	parser, err := participle.Build(&Expression{})
195	ctx.FatalIfErrorf(err)
196
197	expr := &Expression{}
198	err = parser.ParseString(strings.Join(cli.Expression, " "), expr)
199	ctx.FatalIfErrorf(err)
200
201	if cli.AST {
202		json.NewEncoder(os.Stdout).Encode(expr)
203	} else {
204		fmt.Println(expr, "=", expr.Eval(cli.Set))
205	}
206}
207