refactor: 收敛模型参数服务对齐逻辑

This commit is contained in:
wkc
2026-03-16 11:03:19 +08:00
parent 7a3838d00a
commit 5739a7bac0
2 changed files with 193 additions and 182 deletions

View File

@@ -27,7 +27,7 @@ import org.springframework.transaction.annotation.Transactional;
import jakarta.annotation.Resource; import jakarta.annotation.Resource;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Comparator; import java.util.Comparator;
import java.util.Date; import java.util.LinkedHashMap;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.stream.Collectors; import java.util.stream.Collectors;
@@ -49,25 +49,7 @@ public class CcdiModelParamServiceImpl implements ICcdiModelParamService {
@Override @Override
public List<ModelListVO> selectModelList(Long projectId) { public List<ModelListVO> selectModelList(Long projectId) {
log.info("selectModelList 被调用projectId={}", projectId); log.info("selectModelList 被调用projectId={}", projectId);
Long effectiveProjectId = resolveEffectiveProjectId(projectId, false);
if (projectId == null) {
projectId = 0L; // 默认查询系统级参数
}
// 如果是项目查询projectId > 0需要根据 configType 决定查询哪组参数
Long effectiveProjectId = projectId;
if (projectId > 0) {
// 查询项目信息
CcdiProject project = projectMapper.selectById(projectId);
log.info("查询到项目信息: projectId={}, configType={}", projectId,
project != null ? project.getConfigType() : "null");
if (project != null && "default".equals(project.getConfigType())) {
// 使用系统默认参数
effectiveProjectId = 0L;
log.info("项目使用默认配置切换到系统默认参数effectiveProjectId=0");
}
}
log.info("准备查询模型列表effectiveProjectId={}", effectiveProjectId); log.info("准备查询模型列表effectiveProjectId={}", effectiveProjectId);
List<ModelListVO> result = new ArrayList<>(); List<ModelListVO> result = new ArrayList<>();
@@ -86,38 +68,12 @@ public class CcdiModelParamServiceImpl implements ICcdiModelParamService {
@Override @Override
public List<ModelParamVO> selectParamList(ModelParamQueryDTO queryDTO) { public List<ModelParamVO> selectParamList(ModelParamQueryDTO queryDTO) {
// 1. 参数验证 Long effectiveProjectId = resolveEffectiveProjectId(queryDTO.getProjectId(), true);
Long projectId = queryDTO.getProjectId();
if (projectId == null) {
projectId = 0L;
}
// 2. 如果是项目查询projectId > 0需要根据 configType 决定查询哪组参数
Long effectiveProjectId = projectId;
if (projectId > 0) {
// 查询项目信息
CcdiProject project = projectMapper.selectById(projectId);
if (project == null) {
throw new ServiceException("项目不存在");
}
// 根据 configType 决定查询哪组参数
if ("default".equals(project.getConfigType())) {
// 使用系统默认参数
effectiveProjectId = 0L;
} else {
// 使用项目自定义参数
effectiveProjectId = projectId;
}
}
// 3. 查询参数列表
List<CcdiModelParam> params = modelParamMapper.selectByProjectAndModel( List<CcdiModelParam> params = modelParamMapper.selectByProjectAndModel(
effectiveProjectId, effectiveProjectId,
queryDTO.getModelCode() queryDTO.getModelCode()
); );
// 4. 转换为 VO
List<ModelParamVO> result = new ArrayList<>(); List<ModelParamVO> result = new ArrayList<>();
params.forEach(param -> { params.forEach(param -> {
ModelParamVO vo = new ModelParamVO(); ModelParamVO vo = new ModelParamVO();
@@ -145,31 +101,10 @@ public class CcdiModelParamServiceImpl implements ICcdiModelParamService {
Long projectId = saveDTO.getProjectId(); Long projectId = saveDTO.getProjectId();
// 2. 如果是项目保存projectId > 0需要检查是否首次保存
if (projectId > 0) { if (projectId > 0) {
// 查询项目信息 switchToCustomConfigIfNeeded(getRequiredProject(projectId));
CcdiProject project = projectMapper.selectById(projectId);
if (project == null) {
throw new ServiceException("项目不存在");
} }
// 3. 如果是首次保存configType=default需要复制系统默认参数
if ("default".equals(project.getConfigType())) {
int copiedCount = copyDefaultParamsToProject(projectId, saveDTO.getModelCode());
if (copiedCount == 0) {
log.warn("系统默认参数为空projectId={}, modelCode={}",
projectId, saveDTO.getModelCode());
}
// 更新项目配置类型为 custom
project.setConfigType("custom");
projectMapper.updateById(project);
log.info("项目配置类型已更新为 customprojectId={}", projectId);
}
}
// 4. 更新参数值
String username = SecurityUtils.getUsername(); String username = SecurityUtils.getUsername();
for (ModelParamSaveDTO.ParamValueItem item : saveDTO.getParams()) { for (ModelParamSaveDTO.ParamValueItem item : saveDTO.getParams()) {
int updated = modelParamMapper.updateParamValue( int updated = modelParamMapper.updateParamValue(
@@ -194,74 +129,14 @@ public class CcdiModelParamServiceImpl implements ICcdiModelParamService {
} }
} }
/**
* 复制系统默认参数到项目
*
* @param projectId 项目ID
* @param modelCode 模型编码
* @return 复制的参数数量
*/
private int copyDefaultParamsToProject(Long projectId, String modelCode) {
// 查询系统默认参数
List<CcdiModelParam> defaultParams = modelParamMapper.selectByProjectAndModel(0L, modelCode);
if (defaultParams.isEmpty()) {
log.warn("系统默认参数为空modelCode={}", modelCode);
return 0;
}
// 复制到项目
String username = SecurityUtils.getUsername();
List<CcdiModelParam> projectParams = defaultParams.stream()
.map(param -> {
CcdiModelParam newParam = new CcdiModelParam();
BeanUtils.copyProperties(param, newParam);
newParam.setId(null); // 清空ID让数据库自动生成
newParam.setProjectId(projectId);
// 设置审计字段
newParam.setCreateBy(username);
newParam.setUpdateBy(username);
// create_time 和 update_time 由数据库 NOW() 自动设置
return newParam;
})
.collect(Collectors.toList());
// 批量插入
int count = modelParamMapper.insertBatch(projectParams);
log.info("复制系统默认参数到项目成功projectId={}, modelCode={}, count={}",
projectId, modelCode, count);
return count;
}
@Override @Override
public ModelParamAllVO selectAllParams(Long projectId) { public ModelParamAllVO selectAllParams(Long projectId) {
// 1. 参数验证 Long effectiveProjectId = resolveEffectiveProjectId(projectId, true);
if (projectId == null) {
projectId = 0L;
}
// 2. 如果是项目查询,根据 configType 决定查询哪组参数
Long effectiveProjectId = projectId;
if (projectId > 0) {
CcdiProject project = projectMapper.selectById(projectId);
if (project == null) {
throw new ServiceException("项目不存在");
}
if ("default".equals(project.getConfigType())) {
effectiveProjectId = 0L;
}
}
// 3. 查询所有模型的参数
List<CcdiModelParam> allParams = modelParamMapper.selectByProjectId(effectiveProjectId); List<CcdiModelParam> allParams = modelParamMapper.selectByProjectId(effectiveProjectId);
// 4. 按模型分组
Map<String, List<CcdiModelParam>> groupedParams = allParams.stream() Map<String, List<CcdiModelParam>> groupedParams = allParams.stream()
.collect(Collectors.groupingBy(CcdiModelParam::getModelCode)); .collect(Collectors.groupingBy(CcdiModelParam::getModelCode, LinkedHashMap::new, Collectors.toList()));
// 5. 转换为VO
ModelParamAllVO result = new ModelParamAllVO(); ModelParamAllVO result = new ModelParamAllVO();
List<ModelGroupVO> models = new ArrayList<>(); List<ModelGroupVO> models = new ArrayList<>();
@@ -282,7 +157,6 @@ public class CcdiModelParamServiceImpl implements ICcdiModelParamService {
models.add(groupVO); models.add(groupVO);
}); });
// 6. 按模型编码排序(保证固定顺序)
models.sort(Comparator.comparing(ModelGroupVO::getModelCode)); models.sort(Comparator.comparing(ModelGroupVO::getModelCode));
result.setModels(models); result.setModels(models);
@@ -303,63 +177,15 @@ public class CcdiModelParamServiceImpl implements ICcdiModelParamService {
Long projectId = saveAllDTO.getProjectId(); Long projectId = saveAllDTO.getProjectId();
// 2. 如果是项目保存,检查是否需要复制默认参数
if (projectId > 0) { if (projectId > 0) {
CcdiProject project = projectMapper.selectById(projectId); switchToCustomConfigIfNeeded(getRequiredProject(projectId));
if (project == null) {
throw new ServiceException("项目不存在");
} }
// 如果是首次保存configType=default),需要复制所有模型的系统默认参数
if ("default".equals(project.getConfigType())) {
// 1. 查询所有系统默认参数(所有模型的所有参数)
List<CcdiModelParam> allDefaultParams = modelParamMapper.selectByProjectId(0L);
if (allDefaultParams.isEmpty()) {
log.warn("系统默认参数为空");
return;
}
// 2. 批量复制所有默认参数到项目
String username = SecurityUtils.getUsername();
List<CcdiModelParam> projectParams = new ArrayList<>();
for (CcdiModelParam param : allDefaultParams) {
CcdiModelParam newParam = new CcdiModelParam();
BeanUtils.copyProperties(param, newParam);
newParam.setId(null);
newParam.setProjectId(projectId);
// 设置审计字段
newParam.setCreateBy(username);
newParam.setUpdateBy(username);
// create_time 和 update_time 由数据库 NOW() 自动设置
projectParams.add(newParam);
}
// 3. 批量插入
modelParamMapper.insertBatch(projectParams);
log.info("复制所有系统默认参数到项目成功, projectId={}, count={}",
projectId, projectParams.size());
// 更新项目配置类型为 custom
project.setConfigType("custom");
projectMapper.updateById(project);
}
}
// 3. 批量更新所有模型的参数值(性能优化版本)
String username = SecurityUtils.getUsername(); String username = SecurityUtils.getUsername();
List<CcdiModelParam> updateList = new ArrayList<>(); List<CcdiModelParam> updateList = new ArrayList<>();
// 3.1 收集需要更新的参数
for (ModelParamGroupDTO modelGroup : saveAllDTO.getModels()) { for (ModelParamGroupDTO modelGroup : saveAllDTO.getModels()) {
for (ParamValueItem item : modelGroup.getParams()) { for (ParamValueItem item : modelGroup.getParams()) {
// 查询参数ID用于批量更新
CcdiModelParam queryParam = new CcdiModelParam();
queryParam.setProjectId(projectId);
queryParam.setModelCode(modelGroup.getModelCode());
queryParam.setParamCode(item.getParamCode());
// 使用 MyBatis Plus 查询
CcdiModelParam existingParam = modelParamMapper.selectOne( CcdiModelParam existingParam = modelParamMapper.selectOne(
new LambdaQueryWrapper<CcdiModelParam>() new LambdaQueryWrapper<CcdiModelParam>()
.eq(CcdiModelParam::getProjectId, projectId) .eq(CcdiModelParam::getProjectId, projectId)
@@ -378,7 +204,6 @@ public class CcdiModelParamServiceImpl implements ICcdiModelParamService {
} }
} }
// 3.2 批量更新(一次 SQL 执行)
if (!updateList.isEmpty()) { if (!updateList.isEmpty()) {
modelParamMapper.batchUpdateParamValues(updateList); modelParamMapper.batchUpdateParamValues(updateList);
log.info("批量更新参数成功, count={}", updateList.size()); log.info("批量更新参数成功, count={}", updateList.size());
@@ -390,4 +215,73 @@ public class CcdiModelParamServiceImpl implements ICcdiModelParamService {
throw new ServiceException("批量保存模型参数失败:" + e.getMessage()); throw new ServiceException("批量保存模型参数失败:" + e.getMessage());
} }
} }
private Long resolveEffectiveProjectId(Long projectId, boolean failWhenProjectMissing) {
if (projectId == null || projectId <= 0) {
return 0L;
}
CcdiProject project = projectMapper.selectById(projectId);
log.info("查询到项目信息: projectId={}, configType={}", projectId,
project != null ? project.getConfigType() : "null");
if (project == null) {
if (failWhenProjectMissing) {
throw new ServiceException("项目不存在");
}
return projectId;
}
return "default".equals(project.getConfigType()) ? 0L : projectId;
}
private CcdiProject getRequiredProject(Long projectId) {
CcdiProject project = projectMapper.selectById(projectId);
if (project == null) {
throw new ServiceException("项目不存在");
}
return project;
}
private void switchToCustomConfigIfNeeded(CcdiProject project) {
if (!"default".equals(project.getConfigType())) {
return;
}
int copiedCount = copyAllDefaultParamsToProject(project.getProjectId());
if (copiedCount == 0) {
log.warn("系统默认参数为空projectId={}", project.getProjectId());
return;
}
project.setConfigType("custom");
projectMapper.updateById(project);
log.info("项目配置类型已更新为 customprojectId={}", project.getProjectId());
}
private int copyAllDefaultParamsToProject(Long projectId) {
List<CcdiModelParam> defaultParams = modelParamMapper.selectByProjectId(0L);
if (defaultParams.isEmpty()) {
return 0;
}
String username = SecurityUtils.getUsername();
List<CcdiModelParam> projectParams = defaultParams.stream()
.map(param -> buildProjectParam(param, projectId, username))
.collect(Collectors.toList());
int count = modelParamMapper.insertBatch(projectParams);
log.info("复制所有系统默认参数到项目成功projectId={}, count={}", projectId, count);
return count;
}
private CcdiModelParam buildProjectParam(CcdiModelParam source, Long projectId, String username) {
CcdiModelParam target = new CcdiModelParam();
BeanUtils.copyProperties(source, target);
target.setId(null);
target.setProjectId(projectId);
target.setCreateBy(username);
target.setUpdateBy(username);
return target;
}
} }

View File

@@ -0,0 +1,117 @@
package com.ruoyi.ccdi.project.service.impl;
import com.ruoyi.ccdi.project.domain.CcdiModelParam;
import com.ruoyi.ccdi.project.domain.CcdiProject;
import com.ruoyi.ccdi.project.domain.dto.ModelParamSaveDTO;
import com.ruoyi.ccdi.project.domain.vo.ModelParamAllVO;
import com.ruoyi.ccdi.project.mapper.CcdiModelParamMapper;
import com.ruoyi.ccdi.project.mapper.CcdiProjectMapper;
import com.ruoyi.common.utils.SecurityUtils;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.ArgumentCaptor;
import org.mockito.InjectMocks;
import org.mockito.Mock;
import org.mockito.MockedStatic;
import org.mockito.junit.jupiter.MockitoExtension;
import java.util.List;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.ArgumentMatchers.anyList;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.mockStatic;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
@ExtendWith(MockitoExtension.class)
class CcdiModelParamServiceImplTest {
@InjectMocks
private CcdiModelParamServiceImpl service;
@Mock
private CcdiModelParamMapper modelParamMapper;
@Mock
private CcdiProjectMapper projectMapper;
@Test
void selectAllParams_shouldReadSystemDefaultsForDefaultProject() {
CcdiProject project = new CcdiProject();
project.setProjectId(100L);
project.setConfigType("default");
when(projectMapper.selectById(100L)).thenReturn(project);
when(modelParamMapper.selectByProjectId(0L)).thenReturn(List.of(
buildParam(2L, 0L, "SUSPICIOUS_PART_TIME", "模型B", "P2", "2"),
buildParam(1L, 0L, "LARGE_TRANSACTION", "模型A", "P1", "1")
));
ModelParamAllVO result = service.selectAllParams(100L);
verify(modelParamMapper).selectByProjectId(0L);
assertEquals(2, result.getModels().size());
assertEquals("LARGE_TRANSACTION", result.getModels().get(0).getModelCode());
assertEquals("SUSPICIOUS_PART_TIME", result.getModels().get(1).getModelCode());
}
@Test
@SuppressWarnings("unchecked")
void saveParams_shouldCopyAllSystemDefaultsForDefaultProjectOnFirstSave() {
CcdiProject project = new CcdiProject();
project.setProjectId(123L);
project.setConfigType("default");
when(projectMapper.selectById(123L)).thenReturn(project);
when(modelParamMapper.selectByProjectId(0L)).thenReturn(List.of(
buildParam(1L, 0L, "LARGE_TRANSACTION", "大额交易模型", "SINGLE_TRANSACTION_AMOUNT", "1111"),
buildParam(2L, 0L, "SUSPICIOUS_GAMBLING", "疑似赌博交易模型", "multi_party_amt_min", "500")
));
when(modelParamMapper.insertBatch(anyList())).thenReturn(2);
when(modelParamMapper.updateParamValue(123L, "LARGE_TRANSACTION", "SINGLE_TRANSACTION_AMOUNT", "2222", "admin"))
.thenReturn(1);
ModelParamSaveDTO saveDTO = new ModelParamSaveDTO();
saveDTO.setProjectId(123L);
saveDTO.setModelCode("LARGE_TRANSACTION");
ModelParamSaveDTO.ParamValueItem item = new ModelParamSaveDTO.ParamValueItem();
item.setParamCode("SINGLE_TRANSACTION_AMOUNT");
item.setParamValue("2222");
saveDTO.setParams(List.of(item));
try (MockedStatic<SecurityUtils> mocked = mockStatic(SecurityUtils.class)) {
mocked.when(SecurityUtils::getUsername).thenReturn("admin");
service.saveParams(saveDTO);
}
ArgumentCaptor<List<CcdiModelParam>> captor = ArgumentCaptor.forClass(List.class);
verify(modelParamMapper).insertBatch(captor.capture());
List<CcdiModelParam> copiedParams = captor.getValue();
assertEquals(2, copiedParams.size());
assertTrue(copiedParams.stream().allMatch(param -> Long.valueOf(123L).equals(param.getProjectId())));
assertEquals("custom", project.getConfigType());
verify(projectMapper).updateById(project);
verify(modelParamMapper).updateParamValue(123L, "LARGE_TRANSACTION", "SINGLE_TRANSACTION_AMOUNT", "2222", "admin");
}
private CcdiModelParam buildParam(
Long id,
Long projectId,
String modelCode,
String modelName,
String paramCode,
String paramValue
) {
CcdiModelParam param = new CcdiModelParam();
param.setId(id);
param.setProjectId(projectId);
param.setModelCode(modelCode);
param.setModelName(modelName);
param.setParamCode(paramCode);
param.setParamName(paramCode);
param.setParamValue(paramValue);
param.setSortOrder(1);
return param;
}
}