Branch data Line data Source code
1 : : // This file is part of Eigen, a lightweight C++ template library
2 : : // for linear algebra.
3 : : //
4 : : // Copyright (C) 2017 Gael Guennebaud <gael.guennebaud@inria.fr>
5 : : //
6 : : // This Source Code Form is subject to the terms of the Mozilla
7 : : // Public License v. 2.0. If a copy of the MPL was not distributed
8 : : // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9 : :
10 : : #ifndef EIGEN_SYMBOLIC_INDEX_H
11 : : #define EIGEN_SYMBOLIC_INDEX_H
12 : :
13 : : namespace Eigen {
14 : :
15 : : /** \namespace Eigen::symbolic
16 : : * \ingroup Core_Module
17 : : *
18 : : * This namespace defines a set of classes and functions to build and evaluate symbolic expressions of scalar type Index.
19 : : * Here is a simple example:
20 : : *
21 : : * \code
22 : : * // First step, defines symbols:
23 : : * struct x_tag {}; static const symbolic::SymbolExpr<x_tag> x;
24 : : * struct y_tag {}; static const symbolic::SymbolExpr<y_tag> y;
25 : : * struct z_tag {}; static const symbolic::SymbolExpr<z_tag> z;
26 : : *
27 : : * // Defines an expression:
28 : : * auto expr = (x+3)/y+z;
29 : : *
30 : : * // And evaluate it: (c++14)
31 : : * std::cout << expr.eval(x=6,y=3,z=-13) << "\n";
32 : : *
33 : : * // In c++98/11, only one symbol per expression is supported for now:
34 : : * auto expr98 = (3-x)/2;
35 : : * std::cout << expr98.eval(x=6) << "\n";
36 : : * \endcode
37 : : *
38 : : * It is currently only used internally to define and manipulate the Eigen::last and Eigen::lastp1 symbols in Eigen::seq and Eigen::seqN.
39 : : *
40 : : */
41 : : namespace symbolic {
42 : :
43 : : template<typename Tag> class Symbol;
44 : : template<typename Arg0> class NegateExpr;
45 : : template<typename Arg1,typename Arg2> class AddExpr;
46 : : template<typename Arg1,typename Arg2> class ProductExpr;
47 : : template<typename Arg1,typename Arg2> class QuotientExpr;
48 : :
49 : : // A simple wrapper around an integral value to provide the eval method.
50 : : // We could also use a free-function symbolic_eval...
51 : : template<typename IndexType=Index>
52 : : class ValueExpr {
53 : : public:
54 : : ValueExpr(IndexType val) : m_value(val) {}
55 : : template<typename T>
56 : : IndexType eval_impl(const T&) const { return m_value; }
57 : : protected:
58 : : IndexType m_value;
59 : : };
60 : :
61 : : // Specialization for compile-time value,
62 : : // It is similar to ValueExpr(N) but this version helps the compiler to generate better code.
63 : : template<int N>
64 : : class ValueExpr<internal::FixedInt<N> > {
65 : : public:
66 : 4 : ValueExpr() {}
67 : : template<typename T>
68 : : EIGEN_CONSTEXPR Index eval_impl(const T&) const { return N; }
69 : : };
70 : :
71 : :
72 : : /** \class BaseExpr
73 : : * \ingroup Core_Module
74 : : * Common base class of any symbolic expressions
75 : : */
76 : : template<typename Derived>
77 : : class BaseExpr
78 : : {
79 : : public:
80 : 4 : const Derived& derived() const { return *static_cast<const Derived*>(this); }
81 : :
82 : : /** Evaluate the expression given the \a values of the symbols.
83 : : *
84 : : * \param values defines the values of the symbols, it can either be a SymbolValue or a std::tuple of SymbolValue
85 : : * as constructed by SymbolExpr::operator= operator.
86 : : *
87 : : */
88 : : template<typename T>
89 : : Index eval(const T& values) const { return derived().eval_impl(values); }
90 : :
91 : : #if EIGEN_HAS_CXX14
92 : : template<typename... Types>
93 : : Index eval(Types&&... values) const { return derived().eval_impl(std::make_tuple(values...)); }
94 : : #endif
95 : :
96 : : NegateExpr<Derived> operator-() const { return NegateExpr<Derived>(derived()); }
97 : :
98 : : AddExpr<Derived,ValueExpr<> > operator+(Index b) const
99 : : { return AddExpr<Derived,ValueExpr<> >(derived(), b); }
100 : : AddExpr<Derived,ValueExpr<> > operator-(Index a) const
101 : : { return AddExpr<Derived,ValueExpr<> >(derived(), -a); }
102 : : ProductExpr<Derived,ValueExpr<> > operator*(Index a) const
103 : : { return ProductExpr<Derived,ValueExpr<> >(derived(),a); }
104 : : QuotientExpr<Derived,ValueExpr<> > operator/(Index a) const
105 : : { return QuotientExpr<Derived,ValueExpr<> >(derived(),a); }
106 : :
107 : : friend AddExpr<Derived,ValueExpr<> > operator+(Index a, const BaseExpr& b)
108 : : { return AddExpr<Derived,ValueExpr<> >(b.derived(), a); }
109 : : friend AddExpr<NegateExpr<Derived>,ValueExpr<> > operator-(Index a, const BaseExpr& b)
110 : : { return AddExpr<NegateExpr<Derived>,ValueExpr<> >(-b.derived(), a); }
111 : : friend ProductExpr<ValueExpr<>,Derived> operator*(Index a, const BaseExpr& b)
112 : : { return ProductExpr<ValueExpr<>,Derived>(a,b.derived()); }
113 : : friend QuotientExpr<ValueExpr<>,Derived> operator/(Index a, const BaseExpr& b)
114 : : { return QuotientExpr<ValueExpr<>,Derived>(a,b.derived()); }
115 : :
116 : : template<int N>
117 : 4 : AddExpr<Derived,ValueExpr<internal::FixedInt<N> > > operator+(internal::FixedInt<N>) const
118 : 4 : { return AddExpr<Derived,ValueExpr<internal::FixedInt<N> > >(derived(), ValueExpr<internal::FixedInt<N> >()); }
119 : : template<int N>
120 : : AddExpr<Derived,ValueExpr<internal::FixedInt<-N> > > operator-(internal::FixedInt<N>) const
121 : : { return AddExpr<Derived,ValueExpr<internal::FixedInt<-N> > >(derived(), ValueExpr<internal::FixedInt<-N> >()); }
122 : : template<int N>
123 : : ProductExpr<Derived,ValueExpr<internal::FixedInt<N> > > operator*(internal::FixedInt<N>) const
124 : : { return ProductExpr<Derived,ValueExpr<internal::FixedInt<N> > >(derived(),ValueExpr<internal::FixedInt<N> >()); }
125 : : template<int N>
126 : : QuotientExpr<Derived,ValueExpr<internal::FixedInt<N> > > operator/(internal::FixedInt<N>) const
127 : : { return QuotientExpr<Derived,ValueExpr<internal::FixedInt<N> > >(derived(),ValueExpr<internal::FixedInt<N> >()); }
128 : :
129 : : template<int N>
130 : : friend AddExpr<Derived,ValueExpr<internal::FixedInt<N> > > operator+(internal::FixedInt<N>, const BaseExpr& b)
131 : : { return AddExpr<Derived,ValueExpr<internal::FixedInt<N> > >(b.derived(), ValueExpr<internal::FixedInt<N> >()); }
132 : : template<int N>
133 : : friend AddExpr<NegateExpr<Derived>,ValueExpr<internal::FixedInt<N> > > operator-(internal::FixedInt<N>, const BaseExpr& b)
134 : : { return AddExpr<NegateExpr<Derived>,ValueExpr<internal::FixedInt<N> > >(-b.derived(), ValueExpr<internal::FixedInt<N> >()); }
135 : : template<int N>
136 : : friend ProductExpr<ValueExpr<internal::FixedInt<N> >,Derived> operator*(internal::FixedInt<N>, const BaseExpr& b)
137 : : { return ProductExpr<ValueExpr<internal::FixedInt<N> >,Derived>(ValueExpr<internal::FixedInt<N> >(),b.derived()); }
138 : : template<int N>
139 : : friend QuotientExpr<ValueExpr<internal::FixedInt<N> >,Derived> operator/(internal::FixedInt<N>, const BaseExpr& b)
140 : : { return QuotientExpr<ValueExpr<internal::FixedInt<N> > ,Derived>(ValueExpr<internal::FixedInt<N> >(),b.derived()); }
141 : :
142 : : #if (!EIGEN_HAS_CXX14)
143 : : template<int N>
144 : : AddExpr<Derived,ValueExpr<internal::FixedInt<N> > > operator+(internal::FixedInt<N> (*)()) const
145 : : { return AddExpr<Derived,ValueExpr<internal::FixedInt<N> > >(derived(), ValueExpr<internal::FixedInt<N> >()); }
146 : : template<int N>
147 : : AddExpr<Derived,ValueExpr<internal::FixedInt<-N> > > operator-(internal::FixedInt<N> (*)()) const
148 : : { return AddExpr<Derived,ValueExpr<internal::FixedInt<-N> > >(derived(), ValueExpr<internal::FixedInt<-N> >()); }
149 : : template<int N>
150 : : ProductExpr<Derived,ValueExpr<internal::FixedInt<N> > > operator*(internal::FixedInt<N> (*)()) const
151 : : { return ProductExpr<Derived,ValueExpr<internal::FixedInt<N> > >(derived(),ValueExpr<internal::FixedInt<N> >()); }
152 : : template<int N>
153 : : QuotientExpr<Derived,ValueExpr<internal::FixedInt<N> > > operator/(internal::FixedInt<N> (*)()) const
154 : : { return QuotientExpr<Derived,ValueExpr<internal::FixedInt<N> > >(derived(),ValueExpr<internal::FixedInt<N> >()); }
155 : :
156 : : template<int N>
157 : : friend AddExpr<Derived,ValueExpr<internal::FixedInt<N> > > operator+(internal::FixedInt<N> (*)(), const BaseExpr& b)
158 : : { return AddExpr<Derived,ValueExpr<internal::FixedInt<N> > >(b.derived(), ValueExpr<internal::FixedInt<N> >()); }
159 : : template<int N>
160 : : friend AddExpr<NegateExpr<Derived>,ValueExpr<internal::FixedInt<N> > > operator-(internal::FixedInt<N> (*)(), const BaseExpr& b)
161 : : { return AddExpr<NegateExpr<Derived>,ValueExpr<internal::FixedInt<N> > >(-b.derived(), ValueExpr<internal::FixedInt<N> >()); }
162 : : template<int N>
163 : : friend ProductExpr<ValueExpr<internal::FixedInt<N> >,Derived> operator*(internal::FixedInt<N> (*)(), const BaseExpr& b)
164 : : { return ProductExpr<ValueExpr<internal::FixedInt<N> >,Derived>(ValueExpr<internal::FixedInt<N> >(),b.derived()); }
165 : : template<int N>
166 : : friend QuotientExpr<ValueExpr<internal::FixedInt<N> >,Derived> operator/(internal::FixedInt<N> (*)(), const BaseExpr& b)
167 : : { return QuotientExpr<ValueExpr<internal::FixedInt<N> > ,Derived>(ValueExpr<internal::FixedInt<N> >(),b.derived()); }
168 : : #endif
169 : :
170 : :
171 : : template<typename OtherDerived>
172 : : AddExpr<Derived,OtherDerived> operator+(const BaseExpr<OtherDerived> &b) const
173 : : { return AddExpr<Derived,OtherDerived>(derived(), b.derived()); }
174 : :
175 : : template<typename OtherDerived>
176 : : AddExpr<Derived,NegateExpr<OtherDerived> > operator-(const BaseExpr<OtherDerived> &b) const
177 : : { return AddExpr<Derived,NegateExpr<OtherDerived> >(derived(), -b.derived()); }
178 : :
179 : : template<typename OtherDerived>
180 : : ProductExpr<Derived,OtherDerived> operator*(const BaseExpr<OtherDerived> &b) const
181 : : { return ProductExpr<Derived,OtherDerived>(derived(), b.derived()); }
182 : :
183 : : template<typename OtherDerived>
184 : : QuotientExpr<Derived,OtherDerived> operator/(const BaseExpr<OtherDerived> &b) const
185 : : { return QuotientExpr<Derived,OtherDerived>(derived(), b.derived()); }
186 : : };
187 : :
188 : : template<typename T>
189 : : struct is_symbolic {
190 : : // BaseExpr has no conversion ctor, so we only have to check whether T can be statically cast to its base class BaseExpr<T>.
191 : : enum { value = internal::is_convertible<T,BaseExpr<T> >::value };
192 : : };
193 : :
194 : : /** Represents the actual value of a symbol identified by its tag
195 : : *
196 : : * It is the return type of SymbolValue::operator=, and most of the time this is only way it is used.
197 : : */
198 : : template<typename Tag>
199 : : class SymbolValue
200 : : {
201 : : public:
202 : : /** Default constructor from the value \a val */
203 : : SymbolValue(Index val) : m_value(val) {}
204 : :
205 : : /** \returns the stored value of the symbol */
206 : : Index value() const { return m_value; }
207 : : protected:
208 : : Index m_value;
209 : : };
210 : :
211 : : /** Expression of a symbol uniquely identified by the template parameter type \c tag */
212 : : template<typename tag>
213 : : class SymbolExpr : public BaseExpr<SymbolExpr<tag> >
214 : : {
215 : : public:
216 : : /** Alias to the template parameter \c tag */
217 : : typedef tag Tag;
218 : :
219 : 4 : SymbolExpr() {}
220 : :
221 : : /** Associate the value \a val to the given symbol \c *this, uniquely identified by its \c Tag.
222 : : *
223 : : * The returned object should be passed to ExprBase::eval() to evaluate a given expression with this specified runtime-time value.
224 : : */
225 : : SymbolValue<Tag> operator=(Index val) const {
226 : : return SymbolValue<Tag>(val);
227 : : }
228 : :
229 : : Index eval_impl(const SymbolValue<Tag> &values) const { return values.value(); }
230 : :
231 : : #if EIGEN_HAS_CXX14
232 : : // C++14 versions suitable for multiple symbols
233 : : template<typename... Types>
234 : : Index eval_impl(const std::tuple<Types...>& values) const { return std::get<SymbolValue<Tag> >(values).value(); }
235 : : #endif
236 : : };
237 : :
238 : : template<typename Arg0>
239 : : class NegateExpr : public BaseExpr<NegateExpr<Arg0> >
240 : : {
241 : : public:
242 : : NegateExpr(const Arg0& arg0) : m_arg0(arg0) {}
243 : :
244 : : template<typename T>
245 : : Index eval_impl(const T& values) const { return -m_arg0.eval_impl(values); }
246 : : protected:
247 : : Arg0 m_arg0;
248 : : };
249 : :
250 : : template<typename Arg0, typename Arg1>
251 : : class AddExpr : public BaseExpr<AddExpr<Arg0,Arg1> >
252 : : {
253 : : public:
254 : 4 : AddExpr(const Arg0& arg0, const Arg1& arg1) : m_arg0(arg0), m_arg1(arg1) {}
255 : :
256 : : template<typename T>
257 : : Index eval_impl(const T& values) const { return m_arg0.eval_impl(values) + m_arg1.eval_impl(values); }
258 : : protected:
259 : : Arg0 m_arg0;
260 : : Arg1 m_arg1;
261 : : };
262 : :
263 : : template<typename Arg0, typename Arg1>
264 : : class ProductExpr : public BaseExpr<ProductExpr<Arg0,Arg1> >
265 : : {
266 : : public:
267 : : ProductExpr(const Arg0& arg0, const Arg1& arg1) : m_arg0(arg0), m_arg1(arg1) {}
268 : :
269 : : template<typename T>
270 : : Index eval_impl(const T& values) const { return m_arg0.eval_impl(values) * m_arg1.eval_impl(values); }
271 : : protected:
272 : : Arg0 m_arg0;
273 : : Arg1 m_arg1;
274 : : };
275 : :
276 : : template<typename Arg0, typename Arg1>
277 : : class QuotientExpr : public BaseExpr<QuotientExpr<Arg0,Arg1> >
278 : : {
279 : : public:
280 : : QuotientExpr(const Arg0& arg0, const Arg1& arg1) : m_arg0(arg0), m_arg1(arg1) {}
281 : :
282 : : template<typename T>
283 : : Index eval_impl(const T& values) const { return m_arg0.eval_impl(values) / m_arg1.eval_impl(values); }
284 : : protected:
285 : : Arg0 m_arg0;
286 : : Arg1 m_arg1;
287 : : };
288 : :
289 : : } // end namespace symbolic
290 : :
291 : : } // end namespace Eigen
292 : :
293 : : #endif // EIGEN_SYMBOLIC_INDEX_H
|