Java 反射 + 注解:根据实体类对象生成 SQL 语句的工具类

13年前
由于觉得配置 Hibernate 过于繁琐,索性改用了 Spring JDBC,可这样又要手写大量 SQL 语句。为了偷懒,就写了一个能根据实体类对象生成 SQL 语句的工具类。

目前只在 MySQL 数据库上测试通过,其他数据库尚未测试。

本工具类还有很多不足之处,不过好在可以满足自己一些简单的日常使用。

上代码了。

字段类型:

1 package net.tjnwdseip.util;
2  
3 public enum FieldType {
4  
5     STRING,NUMBER,DATE
6 }

字段注解:
01 package net.tjnwdseip.util;
02  
03 import java.lang.annotation.Documented;
04 import java.lang.annotation.ElementType;
05 import java.lang.annotation.Retention;
06 import java.lang.annotation.RetentionPolicy;
07 import java.lang.annotation.Target;
08  
09 @Documented
10 @Retention(RetentionPolicy.RUNTIME)
11 @Target(ElementType.FIELD)
12 public <a href="http://my.oschina.net/interface" class="referer" target="_blank">@interface</a>  FieldAnnotation {
13  
14     String fieldName();
15      
16     FieldType fieldType();
17      
18     boolean pk();
19 }

表名注解:
01 package net.tjnwdseip.util;
02  
03 import java.lang.annotation.Documented;
04 import java.lang.annotation.ElementType;
05 import java.lang.annotation.Retention;
06 import java.lang.annotation.RetentionPolicy;
07 import java.lang.annotation.Target;
08  
09 @Documented
10 @Retention(RetentionPolicy.RUNTIME)
11 @Target(ElementType.TYPE)
12 public <a href="http://my.oschina.net/interface" class="referer" target="_blank">@interface</a>  TableAnnotation {
13  
14     String tableName();
15 }

SQL语句生成工具类:
001 package net.tjnwdseip.util;
002  
003 import java.lang.reflect.Field;
004 import java.lang.reflect.InvocationTargetException;
005 import java.lang.reflect.Method;
006 import java.util.ArrayList;
007 import java.util.HashMap;
008 import java.util.Iterator;
009 import java.util.List;
010  
011 /**
012  *
013  * @ClassName: CreateSqlTools
014  * @Description: TODO(根据实体类对象生成SQL语句)
015  * <a href="http://my.oschina.net/arthor" class="referer" target="_blank">@author</a>  LiYang
016  * @date 2012-5-4 下午10:07:03
017  *
018  */
019 public class CreateSqlTools {
020  
021     /**
022      *
023      * @Title: getTableName
024      * @Description: TODO(获取表名)
025      * @param @param obj
026      * @param @return 设定文件
027      * @return String 返回类型
028      * @throws
029      */
030     private static String getTableName(Object obj) {
031         String tableName = null;
032         if (obj.getClass().isAnnotationPresent(TableAnnotation.class)) {
033             tableName = obj.getClass().getAnnotation(TableAnnotation.class)
034                     .tableName();
035         }
036         return tableName;
037     }
038  
039     /**
040      *
041      * @Title: getAnnoFieldList
042      * @Description: TODO(获取所有有注释的字段,支持多重继承)
043      * @param @param obj
044      * @param @return 设定文件
045      * @return List<Field> 返回类型
046      * @throws
047      */
048     @SuppressWarnings("rawtypes")
049     private static List<Field> getAnnoFieldList(Object obj) {
050         List<Field> list = new ArrayList<Field>();
051         Class superClass = obj.getClass().getSuperclass();
052         while (true) {
053             if (superClass != null) {
054                 Field[] superFields = superClass.getDeclaredFields();
055                 if (superFields != null && superFields.length > 0) {
056                     for (Field field : superFields) {
057                         if (field.isAnnotationPresent(FieldAnnotation.class)) {
058                             list.add(field);
059                         }
060                     }
061                 }
062                 superClass = superClass.getSuperclass();
063             } else {
064                 break;
065             }
066         }
067         Field[] objFields = obj.getClass().getDeclaredFields();
068         if (objFields != null && objFields.length > 0) {
069             for (Field field : objFields) {
070                 if (field.isAnnotationPresent(FieldAnnotation.class)) {
071                     list.add(field);
072                 }
073             }
074         }
075         return list;
076     }
077  
078     /**
079      *
080      * @Title: getFieldValue
081      * @Description: TODO(获取字段的值,支持多重继承)
082      * @param @param obj
083      * @param @param field
084      * @param @return 设定文件
085      * @return String 返回类型
086      * @throws
087      */
088     @SuppressWarnings({ "rawtypes" })
089     private static String getFieldValue(Object obj, Field field) {
090         String value = null;
091         String name = field.getName();
092         String methodName = "get" + name.substring(0, 1).toUpperCase()
093                 + name.substring(1);
094         Method method = null;
095         Object methodValue = null;
096         try {
097             method = obj.getClass().getMethod(methodName);
098         } catch (NoSuchMethodException | SecurityException e1) {
099             // TODO Auto-generated catch block
100         }
101         if (method != null) {
102             try {
103                 methodValue = method.invoke(obj);
104             } catch (IllegalAccessException | IllegalArgumentException
105                     | InvocationTargetException e) {
106                 // TODO Auto-generated catch block
107             }
108             if (methodValue != null) {
109                 value = methodValue.toString();
110             } else {
111                 Class objSuperClass = obj.getClass().getSuperclass();
112                 while (true) {
113                     if (objSuperClass != null) {
114                         try {
115                             methodValue = method.invoke(objSuperClass);
116                         } catch (IllegalAccessException
117                                 | IllegalArgumentException
118                                 | InvocationTargetException e) {
119                             // TODO Auto-generated catch block
120                         }
121                         if (methodValue != null) {
122                             value = methodValue.toString();
123                             break;
124                         } else {
125                             objSuperClass = objSuperClass.getSuperclass();
126                         }
127                     } else {
128                         break;
129                     }
130                 }
131             }
132         }
133         return value;
134     }
135  
136     /**
137      *
138      * @Title: getInsertSql
139      * @Description: TODO(根据实体类对象字段的值生成INSERT SQL语句,可选固定参数)
140      * @param @param obj
141      * @param @param fixedParams
142      *        固定参数(如该参数与实体类中有相同的字段,则忽略实体类中的对应字段,HashMap<String
143      *        ,String>,key=指定字段名,value=对应字段的值)
144      * @param @return 设定文件
145      * @return String 返回类型
146      * @throws
147      */
148     public static String getInsertSql(Object obj,
149             HashMap<String, String> fixedParams) {
150         String insertSql = null;
151         String tableName = getTableName(obj);
152         if (tableName != null) {
153             StringBuffer sqlStr = new StringBuffer("INSERT INTO ");
154             StringBuffer valueStr = new StringBuffer(" VALUES (");
155             List<Field> annoFieldList = getAnnoFieldList(obj);
156             if (annoFieldList != null && annoFieldList.size() > 0) {
157                 sqlStr.append(tableName + " (");
158                 if (fixedParams != null && fixedParams.size() > 0) {
159                     Iterator<String> keyNames = fixedParams.keySet().iterator();
160                     while (keyNames.hasNext()) {
161                         String keyName = (String) keyNames.next();
162                         sqlStr.append(keyName + ",");
163                         valueStr.append(fixedParams.get(keyName) + ",");
164                     }
165                 }
166                 for (Field field : annoFieldList) {
167                     FieldAnnotation anno = field
168                             .getAnnotation(FieldAnnotation.class);
169                     if (!anno.pk()) {
170                         Object fieldValue = getFieldValue(obj, field);
171                         if (fieldValue != null) {
172                             if (fixedParams != null && fixedParams.size() > 0) {
173                                 Iterator<String> keyNames = fixedParams
174                                         .keySet().iterator();
175                                 boolean nextFieldFlag = false;
176                                 while (keyNames.hasNext()) {
177                                     String keyName = (String) keyNames.next();
178                                     if (anno.fieldName().equals(keyName)) {
179                                         nextFieldFlag = true;
180                                         break;
181                                     }
182                                 }
183                                 if (nextFieldFlag) {
184                                     break;
185                                 }
186                             }
187                             sqlStr.append(anno.fieldName() + ",");
188                             switch (anno.fieldType()) {
189                             case NUMBER:
190                                 valueStr.append(fieldValue + ",");
191                                 break;
192                             default:
193                                 valueStr.append("'" + fieldValue + "',");
194                                 break;
195                             }
196                         }
197                     }
198                 }
199                 insertSql = sqlStr.toString().substring(0, sqlStr.length() - 1)
200                         + ")"
201                         + valueStr.toString().substring(0,
202                                 valueStr.length() - 1) + ")";
203             }
204         }
205         return insertSql;
206     }
207  
208     /**
209      *
210      * @Title: getInsertSql
211      * @Description: TODO(根据实体类对象字段的值生成INSERT SQL语句)
212      * @param @param obj
213      * @param @return 设定文件
214      * @return String 返回类型
215      * @throws
216      */
217     public static String getInsertSql(Object obj) {
218         return getInsertSql(obj, null);
219     }
220  
221     /**
222      *
223      * @Title: getUpdateSql
224      * @Description: TODO(根据实体类对象字段的值生成UPDATE SQL语句,可选更新条件为主键,可选固定更新参数)
225      * @param @param obj
226      * @param @param reqPk 是否指定更新条件为主键(true=是,false=否)
227      * @param @param fixedParams
228      *        固定参数(如该参数与实体类中有相同的字段,则忽略实体类中的对应字段,HashMap<String
229      *        ,String>,key=指定字段名,value=对应字段的值)
230      * @param @return 设定文件
231      * @return String 返回类型
232      * @throws
233      */
234     public static String getUpdateSql(Object obj, boolean reqPk,
235             HashMap<String, String> fixedParams) {
236         String updateSql = null;
237         String tableName = getTableName(obj);
238         if (tableName != null) {
239             List<Field> annoFieldList = getAnnoFieldList(obj);
240             if (annoFieldList != null && annoFieldList.size() > 0) {
241                 StringBuffer sqlStr = new StringBuffer("UPDATE " + tableName);
242                 StringBuffer valueStr = new StringBuffer(" SET ");
243                 String whereStr = " WHERE ";
244                 if (fixedParams != null && fixedParams.size() > 0) {
245                     Iterator<String> keyNames = fixedParams.keySet().iterator();
246                     while (keyNames.hasNext()) {
247                         String keyName = (String) keyNames.next();
248                         valueStr.append(keyName + "="
249                                 + fixedParams.get(keyName) + ",");
250                     }
251                 }
252                 for (Field field : annoFieldList) {
253                     String fieldValue = getFieldValue(obj, field);
254                     if (fieldValue != null) {
255                         FieldAnnotation anno = field
256                                 .getAnnotation(FieldAnnotation.class);
257                         if (!anno.pk()) {
258                             if (fixedParams != null && fixedParams.size() > 0) {
259                                 boolean nextFieldFlag = false;
260                                 Iterator<String> keyNames = fixedParams
261                                         .keySet().iterator();
262                                 while (keyNames.hasNext()) {
263                                     String keyName = (String) keyNames.next();
264                                     if (anno.fieldName().equals(keyName)) {
265                                         nextFieldFlag = true;
266                                         break;
267                                     }
268                                 }
269                                 if (nextFieldFlag) {
270                                     break;
271                                 }
272                             }
273                             valueStr.append(anno.fieldName() + "=");
274                             switch (anno.fieldType()) {
275                             case NUMBER:
276                                 valueStr.append(fieldValue + ",");
277                                 break;
278                             default:
279                                 valueStr.append("'" + fieldValue + "',");
280                                 break;
281                             }
282                         } else {
283                             if (reqPk) {
284                                 whereStr += anno.fieldName() + "=" + fieldValue;
285                             }
286                         }
287                     }
288                 }
289                 updateSql = sqlStr.toString()
290                         + valueStr.toString().substring(0,
291                                 valueStr.length() - 1)
292                         + (reqPk ? whereStr : "");
293             }
294         }
295         return updateSql;
296     }
297  
298     /**
299      *
300      * @Title: getUpdateSql
301      * @Description: TODO(根据实体类对象字段的值生成UPDATE SQL语句,无条件)
302      * @param @param obj
303      * @param @return 设定文件
304      * @return String 返回类型
305      * @throws
306      */
307     public static String getUpdateSql(Object obj) {
308         return getUpdateSql(obj, false, null);
309     }
310  
311     /**
312      *
313      * @Title: getUpdateSql
314      * @Description: TODO(根据实体类对象字段的值生成UPDATE SQL语句,可选更新条件为主键)
315      * @param @param obj
316      * @param @param reqPk 是否指定更新条件为主键(true=是,false=否)
317      * @param @return 设定文件
318      * @return String 返回类型
319      * @throws
320      */
321     public static String getUpdateSql(Object obj, boolean reqPk) {
322         return getUpdateSql(obj, reqPk, null);
323     }
324  
325     /**
326      *
327      * @Title: getDeleteSql
328      * @Description: TODO(根据实体类对象字段的值生成有条件的DELETE
329      *               SQL语句,可选主键为删除条件或使用各个字段的值为条件,多个条件用AND连接)
330      * @param @param obj
331      * @param @param reqPk 是否指定更新条件为主键(true=是,false=否)
332      * @param @return 设定文件
333      * @return String 返回类型
334      * @throws
335      */
336     public static String getDeleteSql(Object obj, boolean reqPk) {
337         String deleteSql = null;
338         String tableName = getTableName(obj);
339         if (tableName != null) {
340             StringBuffer delSqlBuffer = new StringBuffer("DELETE FROM ");
341             List<Field> annoFieldList = getAnnoFieldList(obj);
342             if (annoFieldList != null && annoFieldList.size() > 0) {
343                 delSqlBuffer.append(tableName + " WHERE ");
344                 for (Field field : annoFieldList) {
345                     if (reqPk) {
346                         FieldAnnotation anno = field
347                                 .getAnnotation(FieldAnnotation.class);
348                         if (anno.pk()) {
349                             String fieldValue = getFieldValue(obj, field);
350                             delSqlBuffer.append(anno.fieldName() + "=");
351                             switch (anno.fieldType()) {
352                             case NUMBER:
353                                 delSqlBuffer.append(fieldValue);
354                                 break;
355                             default:
356                                 delSqlBuffer.append("'" + fieldValue + "'");
357                                 break;
358                             }
359                             break;
360                         }
361                     } else {
362                         String fieldValue = getFieldValue(obj, field);
363                         if (fieldValue != null) {
364                             FieldAnnotation anno = field
365                                     .getAnnotation(FieldAnnotation.class);
366                             delSqlBuffer.append(anno.fieldName() + "=");
367                             switch (anno.fieldType()) {
368                             case NUMBER:
369                                 delSqlBuffer.append(fieldValue + " AND ");
370                                 break;
371                             default:
372                                 delSqlBuffer
373                                         .append("'" + fieldValue + "' AND ");
374                                 break;
375                             }
376                         }
377                     }
378                 }
379                 if (reqPk) {
380                     deleteSql = delSqlBuffer.toString();
381                 } else {
382                     deleteSql = delSqlBuffer.toString().substring(0,
383                             delSqlBuffer.length() - 5);
384                 }
385             }
386         }
387         return deleteSql;
388     }
389  
390     /**
391      *
392      * @Title: getDeleteSql
393      * @Description: TODO(根据实体类对象字段的值生成有条件的DELETE SQL语句,使用各个字段的值为条件,多个条件用AND连接)
394      * @param @param obj
395      * @param @return 设定文件
396      * @return String 返回类型
397      * @throws
398      */
399     public static String getDeleteSql(Object obj) {
400         return getDeleteSql(obj, false);
401     }
402  
403     /**
404      *
405      * @Title: getSelectAllSql
406      * @Description: TODO(根据实体类对象字段的值生成SELECT SQL语句,无查询条件)
407      * @param @param obj
408      * @param @return 设定文件
409      * @return String 返回类型
410      * @throws
411      */
412     public static String getSelectAllSql(Object obj) {
413         String selectSql = null;
414         String tableName = getTableName(obj);
415         if (tableName != null) {
416             StringBuffer selectBuffer = new StringBuffer("SELECT ");
417             List<Field> annoFieldList = getAnnoFieldList(obj);
418             if (annoFieldList != null && annoFieldList.size() > 0) {
419                 for (Field field : annoFieldList) {
420                     FieldAnnotation anno = field
421                             .getAnnotation(FieldAnnotation.class);
422                     selectBuffer.append(anno.fieldName() + ",");
423                 }
424                 selectSql = selectBuffer.toString().substring(0,
425                         selectBuffer.length() - 1)
426                         + " FROM " + tableName;
427             }
428         }
429         return selectSql;
430     }
431 }

实体类注解写法:
01 package net.tjnwdseip.entity;
02  
03 import java.sql.Timestamp;
04  
05 import net.tjnwdseip.util.FieldAnnotation;
06 import net.tjnwdseip.util.FieldType;
07  
08 public class BaseEntity {
09  
10     @FieldAnnotation(fieldName="id",fieldType=FieldType.NUMBER,pk=true)
11     private Integer id;
12      
13     @FieldAnnotation(fieldName="createDate",fieldType=FieldType.DATE, pk = false)
14     private Timestamp createDate;
15      
16     @FieldAnnotation(fieldName="modifyDate",fieldType=FieldType.DATE, pk = false)
17     private Timestamp modifyDate;
18  
19     public Integer getId() {
20         return id;
21     }
22  
23     public void setId(Integer id) {
24         this.id = id;
25     }
26  
27     public Timestamp getCreateDate() {
28         return createDate;
29     }
30  
31     public void setCreateDate(Timestamp createDate) {
32         this.createDate = createDate;
33     }
34  
35     public Timestamp getModifyDate() {
36         return modifyDate;
37     }
38  
39     public void setModifyDate(Timestamp modifyDate) {
40         this.modifyDate = modifyDate;
41     }
42  
43     public BaseEntity(Integer id, Timestamp createDate, Timestamp modifyDate) {
44         super();
45         this.id = id;
46         this.createDate = createDate;
47         this.modifyDate = modifyDate;
48     }
49  
50     public BaseEntity() {
51         super();
52     }
53 }

01 package net.tjnwdseip.entity;
02  
03 import java.sql.Timestamp;
04  
05 import net.tjnwdseip.util.FieldAnnotation;
06 import net.tjnwdseip.util.FieldType;
07 import net.tjnwdseip.util.TableAnnotation;
08 /**
09  *
10  * @ClassName: SysNetProxyCfg
11  * @Description: TODO(网络代理设置)
12  * <a href="http://my.oschina.net/arthor" class="referer" target="_blank">@author</a>  LiYang
13  * @date 2012-5-2 下午4:13:08
14  *
15  */
16 @TableAnnotation(tableName="sysNetProxyCfg")
17 public class SysNetProxyCfg extends BaseEntity {
18  
19     @FieldAnnotation(fieldName = "name", fieldType = FieldType.STRING, pk = false)
20     private String name;
21      
22     @FieldAnnotation(fieldName = "type", fieldType = FieldType.STRING, pk = false)
23     private String type;
24      
25     @FieldAnnotation(fieldName = "proxyHostIp", fieldType = FieldType.STRING, pk = false)
26     private String proxyHostIp;
27      
28     @FieldAnnotation(fieldName = "proxyPort", fieldType = FieldType.NUMBER, pk = false)
29     private Integer proxyPort;
30  
31     public String getName() {
32         return name;
33     }
34  
35     public void setName(String name) {
36         this.name = name;
37     }
38  
39     public String getType() {
40         return type;
41     }
42  
43     public void setType(String type) {
44         this.type = type;
45     }
46  
47     public String getProxyHostIp() {
48         return proxyHostIp;
49     }
50  
51     public void setProxyHostIp(String proxyHostIp) {
52         this.proxyHostIp = proxyHostIp;
53     }
54  
55     public Integer getProxyPort() {
56         return proxyPort;
57     }
58  
59     public void setProxyPort(Integer proxyPort) {
60         this.proxyPort = proxyPort;
61     }
62  
63     public SysNetProxyCfg(Integer id, Timestamp createDate,
64             Timestamp modifyDate, String name, String type, String proxyHostIp,
65             Integer proxyPort) {
66         super(id, createDate, modifyDate);
67         this.name = name;
68         this.type = type;
69         this.proxyHostIp = proxyHostIp;
70         this.proxyPort = proxyPort;
71     }
72  
73     public SysNetProxyCfg() {
74         super();
75     }
76 }

测试类:
01 package net.tjnwdseip.demo;
02  
03 import java.sql.Timestamp;
04 import java.util.HashMap;
05  
06 import net.tjnwdseip.entity.SysNetProxyCfg;
07 import net.tjnwdseip.util.CreateSqlTools;
08  
09 public class DemoTest {
10  
11      
12     public static void main(String[] args) {
13         // TODO Auto-generated method stub
14         SysNetProxyCfg netProxyCfg = new SysNetProxyCfg(1, Timestamp.valueOf("2012-05-04 14:45:35"), null, "netProxyCfgName", "netProxyCfgType", "000.000.000.000", 0);
15         HashMap<String, String> fixedParams=new HashMap<String,String>();
16         fixedParams.put("createDate", "NOW()");
17         fixedParams.put("modifyDate", "NOW()");
18         System.out.println(CreateSqlTools.getDeleteSql(netProxyCfg));
19         System.out.println(CreateSqlTools.getDeleteSql(netProxyCfg, true));
20         System.out.println(CreateSqlTools.getInsertSql(netProxyCfg));
21         System.out.println(CreateSqlTools.getInsertSql(netProxyCfg, fixedParams));
22         System.out.println(CreateSqlTools.getSelectAllSql(netProxyCfg));
23         System.out.println(CreateSqlTools.getUpdateSql(netProxyCfg));
24         System.out.println(CreateSqlTools.getUpdateSql(netProxyCfg, true));
25         System.out.println(CreateSqlTools.getUpdateSql(netProxyCfg, true, fixedParams));
26     }
27  
28 }

测试结果:
DELETE FROM sysNetProxyCfg WHERE id=1 AND createDate='2012-05-04 14:45:35.0' AND name='netProxyCfgName' AND type='netProxyCfgType' AND proxyHostIp='000.000.000.000' AND proxyPort=0
DELETE FROM sysNetProxyCfg WHERE id=1
INSERT INTO sysNetProxyCfg (createDate,name,type,proxyHostIp,proxyPort) VALUES ('2012-05-04 14:45:35.0','netProxyCfgName','netProxyCfgType','000.000.000.000',0)
INSERT INTO sysNetProxyCfg (modifyDate,createDate) VALUES (NOW(),NOW())
SELECT id,createDate,modifyDate,name,type,proxyHostIp,proxyPort FROM sysNetProxyCfg
UPDATE sysNetProxyCfg SET createDate='2012-05-04 14:45:35.0',name='netProxyCfgName',type='netProxyCfgType',proxyHostIp='000.000.000.000',proxyPort=0
UPDATE sysNetProxyCfg SET createDate='2012-05-04 14:45:35.0',name='netProxyCfgName',type='netProxyCfgType',proxyHostIp='000.000.000.000',proxyPort=0 WHERE id=1
UPDATE sysNetProxyCfg SET modifyDate=NOW(),createDate=NOW() WHERE id=1